diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..d17cb8558c38c80053b5bfb96ad877267b241972 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.pdparams filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4e46b694de24b2437563731ab6b28501400768df --- /dev/null +++ b/.gitignore @@ -0,0 +1,194 @@ +.DS_Store + +/images + +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..12be355bf4cd2dd457586f806e7088bf5ad5ec04 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: Image Matting App +emoji: 🐨 +colorFrom: yellow +colorTo: gray +sdk: gradio +sdk_version: 3.11.0 +app_file: app.py +pinned: false +license: mit +duplicated_from: vivym/image-matting-app +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e72eb8ea0ab7de196cc857b7166820f5b19a3d1d --- /dev/null +++ b/app.py @@ -0,0 +1,172 @@ +from hashlib import sha1 +from pathlib import Path + +import cv2 +import gradio as gr +import numpy as np +from PIL import Image + +from paddleseg.cvlibs import manager, Config +from paddleseg.utils import load_entire_model + +manager.BACKBONES._components_dict.clear() +manager.TRANSFORMS._components_dict.clear() + +import ppmatting as ppmatting +from ppmatting.core import predict +from ppmatting.utils import estimate_foreground_ml + +model_names = [ + "modnet-mobilenetv2", + "ppmatting-512", + "ppmatting-1024", + "ppmatting-2048", + "modnet-hrnet_w18", + "modnet-resnet50_vd", +] +model_dict = { + name: None + for name in model_names +} + +last_result = { + "cache_key": None, + "algorithm": None, +} + + +def image_matting( + image: np.ndarray, + result_type: str, + bg_color: str, + algorithm: str, + morph_op: str, + morph_op_factor: float, +) -> np.ndarray: + image = np.ascontiguousarray(image) + cache_key = sha1(image).hexdigest() + if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]: + alpha = last_result["alpha"] + else: + cfg = Config(f"configs/{algorithm}.yml") + if model_dict[algorithm] is not None: + model = model_dict[algorithm] + else: + model = cfg.model + load_entire_model(model, f"models/{algorithm}.pdparams") + model.eval() + model_dict[algorithm] = model + + transforms = ppmatting.transforms.Compose(cfg.val_transforms) + + alpha = predict( + model, + transforms=transforms, + image=image, + ) + last_result["cache_key"] = cache_key + last_result["algorithm"] = algorithm + last_result["alpha"] = alpha + + alpha = (alpha * 255).astype(np.uint8) + kernel = np.ones((5, 5), np.uint8) + if morph_op == "Dilate": + alpha = cv2.dilate(alpha, kernel, iterations=int(morph_op_factor)) + else: + alpha = cv2.erode(alpha, kernel, iterations=int(morph_op_factor)) + alpha = (alpha / 255).astype(np.float32) + + image = (image / 255.0).astype("float32") + fg = estimate_foreground_ml(image, alpha) + + if result_type == "Remove BG": + result = np.concatenate((fg, alpha[:, :, None]), axis=-1) + elif result_type == "Replace BG": + bg_r = int(bg_color[1:3], base=16) + bg_g = int(bg_color[3:5], base=16) + bg_b = int(bg_color[5:7], base=16) + + bg = np.zeros_like(fg) + bg[:, :, 0] = bg_r / 255. 
+ bg[:, :, 1] = bg_g / 255. + bg[:, :, 2] = bg_b / 255. + + result = alpha[:, :, None] * fg + (1 - alpha[:, :, None]) * bg + result = np.clip(result, 0, 1) + else: + result = alpha + + return result + + +def main(): + with gr.Blocks() as app: + gr.Markdown("Image Matting Powered By AI") + + with gr.Row(variant="panel"): + image_input = gr.Image() + image_output = gr.Image() + + with gr.Row(variant="panel"): + result_type = gr.Radio( + label="Mode", + show_label=True, + choices=[ + "Remove BG", + "Replace BG", + "Generate Mask", + ], + value="Remove BG", + ) + bg_color = gr.ColorPicker( + label="BG Color", + show_label=True, + value="#000000", + ) + algorithm = gr.Dropdown( + label="Algorithm", + show_label=True, + choices=model_names, + value="modnet-hrnet_w18" + ) + + with gr.Row(variant="panel"): + morph_op = gr.Radio( + label="Post-process", + show_label=True, + choices=[ + "Dilate", + "Erode", + ], + value="Dilate", + ) + + morph_op_factor = gr.Slider( + label="Factor", + show_label=True, + minimum=0, + maximum=20, + value=0, + step=1, + ) + + run_button = gr.Button("Run") + + run_button.click( + image_matting, + inputs=[ + image_input, + result_type, + bg_color, + algorithm, + morph_op, + morph_op_factor, + ], + outputs=image_output, + ) + + app.launch() + + +if __name__ == "__main__": + main() diff --git a/configs/modnet-hrnet_w18.yml b/configs/modnet-hrnet_w18.yml new file mode 100644 index 0000000000000000000000000000000000000000..c0c1e00c7c424c3bfa6a6c03ae1f6e9e949b36e9 --- /dev/null +++ b/configs/modnet-hrnet_w18.yml @@ -0,0 +1,5 @@ +_base_: modnet-mobilenetv2.yml +model: + backbone: + type: HRNet_W18 + # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz diff --git a/configs/modnet-mobilenetv2.yml b/configs/modnet-mobilenetv2.yml new file mode 100644 index 0000000000000000000000000000000000000000..80dfd04281e8674fdc35e119ac1a126f844d04c2 --- /dev/null +++ b/configs/modnet-mobilenetv2.yml @@ -0,0 +1,47 @@ +batch_size: 16 +iters: 100000 + +train_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + train_file: train.txt + transforms: + - type: LoadImages + - type: RandomCrop + crop_size: [512, 512] + - type: RandomDistort + - type: RandomBlur + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + val_file: val.txt + transforms: + - type: LoadImages + - type: ResizeByShort + short_size: 512 + - type: ResizeToIntMult + mult_int: 32 + - type: Normalize + mode: val + get_trimap: False + +model: + type: MODNet + backbone: + type: MobileNetV2 + # pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams + pretrained: Null + +optimizer: + type: sgd + momentum: 0.9 + weight_decay: 4.0e-5 + +lr_scheduler: + type: PiecewiseDecay + boundaries: [40000, 80000] + values: [0.02, 0.002, 0.0002] diff --git a/configs/modnet-resnet50_vd.yml b/configs/modnet-resnet50_vd.yml new file mode 100644 index 0000000000000000000000000000000000000000..e1e2104049412d5a4f2c164540ed1e7b7f993efc --- /dev/null +++ b/configs/modnet-resnet50_vd.yml @@ -0,0 +1,5 @@ +_base_: modnet-mobilenetv2.yml +model: + backbone: + type: ResNet50_vd + # pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz diff --git a/configs/ppmatting-1024.yml b/configs/ppmatting-1024.yml new file mode 100644 index 0000000000000000000000000000000000000000..a7ece05efda852ecee950fe37541751d5a15dc90 --- /dev/null +++ b/configs/ppmatting-1024.yml @@ -0,0 +1,29 @@ +_base_: 
'ppmatting-hrnet_w18-human_512.yml' + + +train_dataset: + transforms: + - type: LoadImages + - type: LimitShort + max_short: 1024 + - type: RandomCrop + crop_size: [1024, 1024] + - type: RandomDistort + - type: RandomBlur + prob: 0.1 + - type: RandomNoise + prob: 0.5 + - type: RandomReJpeg + prob: 0.2 + - type: RandomHorizontalFlip + - type: Normalize + +val_dataset: + transforms: + - type: LoadImages + - type: LimitShort + max_short: 1024 + - type: ResizeToIntMult + mult_int: 32 + - type: Normalize + diff --git a/configs/ppmatting-2048.yml b/configs/ppmatting-2048.yml new file mode 100644 index 0000000000000000000000000000000000000000..e8dc2ddbd90139d39a2137bf2086611c6d6c87fe --- /dev/null +++ b/configs/ppmatting-2048.yml @@ -0,0 +1,54 @@ +batch_size: 4 +iters: 50000 + +train_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + train_file: train.txt + transforms: + - type: LoadImages + - type: RandomResize + size: [2048, 2048] + scale: [0.3, 1.5] + - type: RandomCrop + crop_size: [2048, 2048] + - type: RandomDistort + - type: RandomBlur + prob: 0.1 + - type: RandomHorizontalFlip + - type: Padding + target_size: [2048, 2048] + - type: Normalize + mode: train + +val_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + val_file: val.txt + transforms: + - type: LoadImages + - type: ResizeByShort + short_size: 2048 + - type: ResizeToIntMult + mult_int: 128 + - type: Normalize + mode: val + get_trimap: False + +model: + type: HumanMatting + backbone: + type: ResNet34_vd + # pretrained: https://paddleseg.bj.bcebos.com/matting/models/ResNet34_vd_pretrained/model.pdparams + pretrained: Null + if_refine: True + +optimizer: + type: sgd + momentum: 0.9 + weight_decay: 4.0e-5 + +lr_scheduler: + type: PiecewiseDecay + boundaries: [30000, 40000] + values: [0.001, 0.0001, 0.00001] diff --git a/configs/ppmatting-512.yml b/configs/ppmatting-512.yml new file mode 100644 index 0000000000000000000000000000000000000000..e6f9e1cabfe1096a2f49fc26636eb88b56047c2c --- /dev/null +++ b/configs/ppmatting-512.yml @@ -0,0 +1,44 @@ +_base_: 'ppmatting-hrnet_w48-distinctions.yml' + +batch_size: 4 +iters: 200000 + +train_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + train_file: train.txt + transforms: + - type: LoadImages + - type: LimitShort + max_short: 512 + - type: RandomCrop + crop_size: [512, 512] + - type: RandomDistort + - type: RandomBlur + prob: 0.1 + - type: RandomNoise + prob: 0.5 + - type: RandomReJpeg + prob: 0.2 + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: MattingDataset + dataset_root: data/PPM-100 + val_file: val.txt + transforms: + - type: LoadImages + - type: LimitShort + max_short: 512 + - type: ResizeToIntMult + mult_int: 32 + - type: Normalize + mode: val + get_trimap: False + +model: + backbone: + type: HRNet_W18 + # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz diff --git a/configs/ppmatting-hrnet_w48-composition.yml b/configs/ppmatting-hrnet_w48-composition.yml new file mode 100644 index 0000000000000000000000000000000000000000..64f011587223fa635d6368199d9ab237a3266e08 --- /dev/null +++ b/configs/ppmatting-hrnet_w48-composition.yml @@ -0,0 +1,7 @@ +_base_: 'ppmatting-hrnet_w48-distinctions.yml' + +train_dataset: + dataset_root: data/matting/Composition-1k + +val_dataset: + dataset_root: data/matting/Composition-1k \ No newline at end of file diff --git a/configs/ppmatting-hrnet_w48-distinctions.yml b/configs/ppmatting-hrnet_w48-distinctions.yml new file mode 100644 index 
0000000000000000000000000000000000000000..9adb05955bad54eaf19eeb43c92ff2dfccf14eb6 --- /dev/null +++ b/configs/ppmatting-hrnet_w48-distinctions.yml @@ -0,0 +1,55 @@ +batch_size: 4 +iters: 300000 + +train_dataset: + type: MattingDataset + dataset_root: data/matting/Distinctions-646 + train_file: train.txt + transforms: + - type: LoadImages + - type: Padding + target_size: [512, 512] + - type: RandomCrop + crop_size: [[512, 512],[640, 640], [800, 800]] + - type: Resize + target_size: [512, 512] + - type: RandomDistort + - type: RandomBlur + prob: 0.1 + - type: RandomHorizontalFlip + - type: Normalize + mode: train + separator: '|' + +val_dataset: + type: MattingDataset + dataset_root: data/matting/Distinctions-646 + val_file: val.txt + transforms: + - type: LoadImages + - type: LimitShort + max_short: 1536 + - type: ResizeToIntMult + mult_int: 32 + - type: Normalize + mode: val + get_trimap: False + separator: '|' + +model: + type: PPMatting + backbone: + type: HRNet_W48 + # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz + pretrained: Null + +optimizer: + type: sgd + momentum: 0.9 + weight_decay: 4.0e-5 + +lr_scheduler: + type: PolynomialDecay + learning_rate: 0.01 + end_lr: 0 + power: 0.9 \ No newline at end of file diff --git a/models/modnet-hrnet_w18.pdparams b/models/modnet-hrnet_w18.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..34213499a0c3d757aa5e200368f140b54f08a279 --- /dev/null +++ b/models/modnet-hrnet_w18.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02863c8069c11367cdd7d25469ed66d133e9b835fee1f6adc76086eb33c83ac8 +size 41174502 diff --git a/models/modnet-mobilenetv2.pdparams b/models/modnet-mobilenetv2.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..0d5ed93d13232e5ee766620fb4b8f07e20b2dfc8 --- /dev/null +++ b/models/modnet-mobilenetv2.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dbec3ca48dae927354efabd5b9d35e9af9998caf91e544c606af8589ad0528a +size 26143420 diff --git a/models/modnet-resnet50_vd.pdparams b/models/modnet-resnet50_vd.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..9e554af87ba7562953a598de5efd4482a36b09dc --- /dev/null +++ b/models/modnet-resnet50_vd.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77568bbe3120b2490b1167df7c137402b3d3513617f5e69306e7bafd3d9f525e +size 368802825 diff --git a/models/ppmatting-1024.pdparams b/models/ppmatting-1024.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..2481fc522677a0b7970c96da51f7d3a5537cf2f9 --- /dev/null +++ b/models/ppmatting-1024.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8e8db1e8b20a62f24b5ffae4f7e8f4a89d0db11169647967fecf2c3d17c0f99 +size 98439023 diff --git a/models/ppmatting-2048.pdparams b/models/ppmatting-2048.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..d4820a64656d4d130df5c85bafdfea16c0d6d77c --- /dev/null +++ b/models/ppmatting-2048.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b13d1a1284d61d087cb0dd5e1d02178754053f7a03fd456484c77719b2e3a97 +size 255754333 diff --git a/models/ppmatting-512.pdparams b/models/ppmatting-512.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..d286008517e189aa899f184f62c5f9c13f98ba95 --- /dev/null +++ b/models/ppmatting-512.pdparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:0121d1494a4bad8a620e07a935b7bd97374f121ca0f48ba96b56df2972b0e054
+size 98439023
diff --git a/ppmatting/__init__.py b/ppmatting/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1094808e27aa683fc3b5766e9968712b3021532
--- /dev/null
+++ b/ppmatting/__init__.py
@@ -0,0 +1 @@
+from . import ml, metrics, transforms, datasets, models
diff --git a/ppmatting/core/__init__.py b/ppmatting/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78060ba48aac1fd7d8cb32eccc7ccddadd74017f
--- /dev/null
+++ b/ppmatting/core/__init__.py
@@ -0,0 +1,4 @@
+from .val import evaluate
+from .val_ml import evaluate_ml
+from .train import train
+from .predict import predict
\ No newline at end of file
diff --git a/ppmatting/core/predict.py b/ppmatting/core/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ff765d9c62f3cb7b758d1756632cfe65cab0f1
--- /dev/null
+++ b/ppmatting/core/predict.py
@@ -0,0 +1,58 @@
+from typing import Optional
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to its original shape."""
+    for item in trans_info[::-1]:
+        if item[0] == "resize":
+            h, w = item[1][0], item[1][1]
+            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
+        elif item[0] == "padding":
+            h, w = item[1][0], item[1][1]
+            alpha = alpha[:, :, 0:h, 0:w]
+        else:
+            raise Exception(f"Unexpected info '{item[0]}' in im_info")
+
+    return alpha
+
+
+def preprocess(img, transforms, trimap=None):
+    data = {}
+    data["img"] = img
+    if trimap is not None:
+        data["trimap"] = trimap
+        data["gt_fields"] = ["trimap"]
+    data["trans_info"] = []
+    data = transforms(data)
+    data["img"] = paddle.to_tensor(data["img"])
+    data["img"] = data["img"].unsqueeze(0)
+    if trimap is not None:
+        data["trimap"] = paddle.to_tensor(data["trimap"])
+        data["trimap"] = data["trimap"].unsqueeze((0, 1))
+
+    return data
+
+
+def predict(
+    model,
+    transforms,
+    image: np.ndarray,
+    trimap: Optional[np.ndarray] = None,
+):
+    with paddle.no_grad():
+        # Pass the trimap through to preprocessing so trimap-based models receive it.
+        data = preprocess(img=image, transforms=transforms, trimap=trimap)
+
+        alpha = model(data)
+
+        alpha = reverse_transform(alpha, data["trans_info"])
+        alpha = alpha.numpy().squeeze()
+
+        if trimap is not None:
+            alpha[trimap == 0] = 0
+            alpha[trimap == 255] = 1.
+
+        return alpha
diff --git a/ppmatting/core/train.py b/ppmatting/core/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..695a177dcdd13fc7e79cf067a5c6984f5f125904
--- /dev/null
+++ b/ppmatting/core/train.py
@@ -0,0 +1,315 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import time +from collections import deque, defaultdict +import pickle +import shutil + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddleseg.utils import TimeAverager, calculate_eta, resume, logger + +from .val import evaluate + + +def visual_in_traning(log_writer, vis_dict, step): + """ + Visual in vdl + + Args: + log_writer (LogWriter): The log writer of vdl. + vis_dict (dict): Dict of tensor. The shape of thesor is (C, H, W) + """ + for key, value in vis_dict.items(): + value_shape = value.shape + if value_shape[0] not in [1, 3]: + value = value[0] + value = value.unsqueeze(0) + value = paddle.transpose(value, (1, 2, 0)) + min_v = paddle.min(value) + max_v = paddle.max(value) + if (min_v > 0) and (max_v < 1): + value = value * 255 + elif (min_v < 0 and min_v >= -1) and (max_v <= 1): + value = (1 + value) / 2 * 255 + else: + value = (value - min_v) / (max_v - min_v) * 255 + + value = value.astype('uint8') + value = value.numpy() + log_writer.add_image(tag=key, img=value, step=step) + + +def save_best(best_model_dir, metrics_data, iter): + with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f: + for key, value in metrics_data.items(): + line = key + ' ' + str(value) + '\n' + f.write(line) + f.write('iter' + ' ' + str(iter) + '\n') + + +def get_best(best_file, metrics, resume_model=None): + '''Get best metrics and iter from file''' + best_metrics_data = {} + if os.path.exists(best_file) and (resume_model is not None): + values = [] + with open(best_file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.strip() + key, value = line.split(' ') + best_metrics_data[key] = eval(value) + if key == 'iter': + best_iter = eval(value) + else: + for key in metrics: + best_metrics_data[key] = np.inf + best_iter = -1 + return best_metrics_data, best_iter + + +def train(model, + train_dataset, + val_dataset=None, + optimizer=None, + save_dir='output', + iters=10000, + batch_size=2, + resume_model=None, + save_interval=1000, + log_iters=10, + log_image_iters=1000, + num_workers=0, + use_vdl=False, + losses=None, + keep_checkpoint_max=5, + eval_begin_iters=None, + metrics='sad'): + """ + Launch training. + Args: + model(nn.Layer): A matting model. + train_dataset (paddle.io.Dataset): Used to read and process training datasets. + val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets. + optimizer (paddle.optimizer.Optimizer): The optimizer. + save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'. + iters (int, optional): How may iters to train the model. Defualt: 10000. + batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2. + resume_model (str, optional): The path of resume model. + save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000. + log_iters (int, optional): Display logging information at every log_iters. Default: 10. + log_image_iters (int, optional): Log image to vdl. Default: 1000. + num_workers (int, optional): Num workers for data loader. Default: 0. + use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False. + losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None. + keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5. + eval_begin_iters (int): The iters begin evaluation. It will evaluate at iters/2 if it is None. Defalust: None. 
+        metrics (str|list, optional): The metrics to evaluate; it may be a combination of ("sad", "mse", "grad", "conn").
+    """
+    model.train()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+
+    start_iter = 0
+    if resume_model is not None:
+        start_iter = resume(model, optimizer, resume_model)
+
+    if not os.path.isdir(save_dir):
+        if os.path.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+
+    if nranks > 1:
+        # Initialize parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
+            paddle.distributed.init_parallel_env()
+            ddp_model = paddle.DataParallel(model)
+        else:
+            ddp_model = paddle.DataParallel(model)
+
+    batch_sampler = paddle.io.DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+
+    loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        return_list=True, )
+
+    if use_vdl:
+        from visualdl import LogWriter
+        log_writer = LogWriter(save_dir)
+
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    best_metrics_data, best_iter = get_best(
+        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
+        metrics,
+        resume_model=resume_model)
+    avg_loss = defaultdict(float)
+    iters_per_epoch = len(batch_sampler)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    save_models = deque()
+    batch_start = time.time()
+
+    iter = start_iter
+    while iter < iters:
+        for data in loader:
+            iter += 1
+            if iter > iters:
+                break
+            reader_cost_averager.record(time.time() - batch_start)
+
+            logit_dict, loss_dict = ddp_model(data) if nranks > 1 else model(data)
+
+            loss_dict['all'].backward()
+
+            optimizer.step()
+            lr = optimizer.get_lr()
+            if isinstance(optimizer._learning_rate,
+                          paddle.optimizer.lr.LRScheduler):
+                optimizer._learning_rate.step()
+            model.clear_gradients()
+
+            for key, value in loss_dict.items():
+                avg_loss[key] += value.numpy()[0]
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=batch_size)
+
+            if iter % log_iters == 0 and local_rank == 0:
+                for key, value in avg_loss.items():
+                    avg_loss[key] = value / log_iters
+                remain_iters = iters - iter
+                avg_train_batch_cost = batch_cost_averager.get_average()
+                avg_train_reader_cost = reader_cost_averager.get_average()
+                eta = calculate_eta(remain_iters, avg_train_batch_cost)
+                # loss info
+                loss_str = ' ' * 26 + '\t[LOSSES]'
+                for key, value in avg_loss.items():
+                    if key != 'all':
+                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(value)
+                logger.info(
+                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
+                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
+                            avg_loss['all'], lr, avg_train_batch_cost,
+                            avg_train_reader_cost,
+                            batch_cost_averager.get_ips_average(), eta, loss_str))
+                if use_vdl:
+                    for key, value in avg_loss.items():
+                        log_tag = 'Train/' + key
+                        log_writer.add_scalar(log_tag, value, iter)
+
+                    log_writer.add_scalar('Train/lr', lr, iter)
+                    log_writer.add_scalar('Train/batch_cost',
+                                          avg_train_batch_cost, iter)
+                    log_writer.add_scalar('Train/reader_cost',
+                                          avg_train_reader_cost, iter)
+                    if iter % log_image_iters == 0:
+                        vis_dict = {}
+                        # ground truth
+                        vis_dict['ground truth/img'] = data['img'][0]
+                        for key in data['gt_fields']:
+                            key = key[0]
+                            vis_dict['/'.join(['ground truth', key])] = data[key][0]
+                        # predict
+                        for key, value in logit_dict.items():
+                            vis_dict['/'.join(['predict', key])] = logit_dict[key][0]
+                        visual_in_training(
+                            log_writer=log_writer, vis_dict=vis_dict, step=iter)
+
+                for key in avg_loss.keys():
+                    avg_loss[key] = 0.
+                reader_cost_averager.reset()
+                batch_cost_averager.reset()
+
+            # save model
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                current_save_dir = os.path.join(save_dir,
+                                                "iter_{}".format(iter))
+                if not os.path.isdir(current_save_dir):
+                    os.makedirs(current_save_dir)
+                paddle.save(model.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdparams'))
+                paddle.save(optimizer.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdopt'))
+                save_models.append(current_save_dir)
+                if len(save_models) > keep_checkpoint_max > 0:
+                    model_to_remove = save_models.popleft()
+                    shutil.rmtree(model_to_remove)
+
+            # eval model
+            if eval_begin_iters is None:
+                eval_begin_iters = iters // 2
+            if (iter % save_interval == 0 or iter == iters) and (
+                    val_dataset is not None
+            ) and local_rank == 0 and iter >= eval_begin_iters:
+                num_workers = 1 if num_workers > 0 else 0
+                metrics_data = evaluate(
+                    model,
+                    val_dataset,
+                    num_workers=num_workers,
+                    print_detail=True,
+                    save_results=False,
+                    metrics=metrics)
+                model.train()
+
+            # save best model and add evaluation results to vdl
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                if val_dataset is not None and iter >= eval_begin_iters:
+                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
+                        best_iter = iter
+                        best_metrics_data = metrics_data.copy()
+                        best_model_dir = os.path.join(save_dir, "best_model")
+                        paddle.save(
+                            model.state_dict(),
+                            os.path.join(best_model_dir, 'model.pdparams'))
+                        save_best(best_model_dir, best_metrics_data, iter)
+
+                    show_list = []
+                    for key, value in best_metrics_data.items():
+                        show_list.append((key, value))
+                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
+                        show_list[0][0], show_list[0][1], best_iter)
+                    if len(show_list) > 1:
+                        log_str += " While"
+                        for i in range(1, len(show_list)):
+                            log_str = log_str + ' {}: {:.4f},'.format(
+                                show_list[i][0], show_list[i][1])
+                        log_str = log_str[:-1]
+                    logger.info(log_str)
+
+                    if use_vdl:
+                        for key, value in metrics_data.items():
+                            log_writer.add_scalar('Evaluate/' + key, value,
+                                                  iter)
+
+            batch_start = time.time()
+
+    # Sleep for half a second to let dataloader release resources.
+    time.sleep(0.5)
+    if use_vdl:
+        log_writer.close()
diff --git a/ppmatting/core/val.py b/ppmatting/core/val.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e3117725ab3792fc7a2344082ad45f26cb2cd28
--- /dev/null
+++ b/ppmatting/core/val.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from ppmatting.metrics import metrics_class_dict
+
+np.set_printoptions(suppress=True)
+
+
+def save_alpha_pred(alpha, path):
+    """
+    The value of alpha should be in the range [0, 255]; the shape should be [h, w].
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    alpha = alpha.astype('uint8')
+    cv2.imwrite(path, alpha)
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to its original shape."""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
+        elif item[0][0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            alpha = alpha[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return alpha
+
+
+def evaluate(model,
+             eval_dataset,
+             num_workers=0,
+             print_detail=True,
+             save_dir='output/results',
+             save_results=True,
+             metrics='sad'):
+    model.eval()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+    if nranks > 1:
+        # Initialize parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
+            paddle.distributed.init_parallel_env()
+
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_size=1,
+        drop_last=False,
+        num_workers=num_workers,
+        return_list=True, )
+
+    total_iters = len(loader)
+    # Get metric instances and data saving
+    metrics_ins = {}
+    metrics_data = {}
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    for key in metrics:
+        key = key.lower()
+        metrics_ins[key] = metrics_class_dict[key]()
+        metrics_data[key] = None
+
+    if print_detail:
+        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
+                    format(len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(
+        target=total_iters, verbose=1 if nranks < 2 else 2)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+
+    img_name = ''
+    i = 0
+    with paddle.no_grad():
+        for iter, data in enumerate(loader):
+            reader_cost_averager.record(time.time() - batch_start)
+            alpha_pred = model(data)
+
+            alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
+            alpha_pred = alpha_pred.numpy()
+
+            alpha_gt = data['alpha'].numpy() * 255
+            trimap = data.get('ori_trimap')
+            if trimap is not None:
+                trimap = trimap.numpy().astype('uint8')
+            alpha_pred = np.round(alpha_pred * 255)
+            for key in metrics_ins.keys():
+                metrics_data[key] = metrics_ins[key].update(alpha_pred,
+                                                            alpha_gt, trimap)
+
+            if save_results:
+                alpha_pred_one = alpha_pred[0].squeeze()
+                if trimap is not None:
+                    trimap = trimap.squeeze().astype('uint8')
+                    alpha_pred_one[trimap == 255] = 255
+                    alpha_pred_one[trimap == 0] = 0
+
+                save_name = data['img_name'][0]
+                name, ext = os.path.splitext(save_name)
+                if save_name == img_name:
+                    save_name = name + '_' + str(i) + ext
+                    i += 1
+                else:
+                    img_name = save_name
+                    save_name = name + '_' + str(i) + ext
+                    i = 1
+
+                save_alpha_pred(alpha_pred_one,
+                                os.path.join(save_dir, save_name))
+
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=len(alpha_gt))
+            batch_cost = batch_cost_averager.get_average()
+            reader_cost = reader_cost_averager.get_average()
+
+            if local_rank == 0 and print_detail:
+                show_list = [(k, v) for k, v in metrics_data.items()]
+                show_list = show_list + [('batch_cost', batch_cost),
+                                         ('reader cost', reader_cost)]
+                progbar_val.update(iter + 1, show_list)
+
+            reader_cost_averager.reset()
+            batch_cost_averager.reset()
+            batch_start = time.time()
+
+    for key in metrics_ins.keys():
+        metrics_data[key] = metrics_ins[key].evaluate()
+    log_str = '[EVAL] '
+    for key, value in metrics_data.items():
+        log_str = log_str + key + ': {:.4f}, '.format(value)
+    log_str = log_str[:-2]
+
+    logger.info(log_str)
+    return metrics_data
diff --git a/ppmatting/core/val_ml.py b/ppmatting/core/val_ml.py
new file mode 100644
index 0000000000000000000000000000000000000000..77628925bec1fa08a4a24de685355cc71157db92
--- /dev/null
+++ b/ppmatting/core/val_ml.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from ppmatting.metrics import metric
+from pymatting.util.util import load_image, save_image, stack_images
+from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
+
+np.set_printoptions(suppress=True)
+
+
+def save_alpha_pred(alpha, path):
+    """
+    The value of alpha should be in the range [0, 255]; the shape should be [h, w].
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    alpha = alpha.astype('uint8')
+    cv2.imwrite(path, alpha)
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to its original shape."""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
+            alpha = cv2.resize(alpha, dsize=(w, h))
+        elif item[0][0] == 'padding':
+            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
+            alpha = alpha[0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return alpha
+
+
+def evaluate_ml(model,
+                eval_dataset,
+                num_workers=0,
+                print_detail=True,
+                save_dir='output/results',
+                save_results=True):
+
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_size=1,
+        drop_last=False,
+        num_workers=num_workers,
+        return_list=True, )
+
+    total_iters = len(loader)
+    mse_metric = metric.MSE()
+    sad_metric = metric.SAD()
+    grad_metric = metric.Grad()
+    conn_metric = metric.Conn()
+
+    if print_detail:
+        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
+                    format(len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+
+    img_name = ''
+    i = 0
+    ignore_cnt = 0
+    for iter, data in enumerate(loader):
+
+        reader_cost_averager.record(time.time() - batch_start)
+
+        image_rgb_chw = data['img'].numpy()[0]
+        image_rgb_hwc = np.transpose(image_rgb_chw, (1, 2, 0))
+        trimap = data['trimap'].numpy().squeeze() / 255.0
+        image = image_rgb_hwc * 0.5 + 0.5  # reverse normalize (x/255 - mean) / std
+
+        is_fg = trimap >= 0.9
+        is_bg = trimap <= 0.1
+
+        if is_fg.sum() == 0 or is_bg.sum() == 0:
+            ignore_cnt += 1
+            logger.info(str(iter))
+            continue
+
+        alpha_pred = model(image, trimap)
+
+        alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
+
+        alpha_gt = data['alpha'].numpy().squeeze() * 255
+
+        trimap = data['ori_trimap'].numpy().squeeze()
+
+        alpha_pred = np.round(alpha_pred * 255)
+        mse = mse_metric.update(alpha_pred, alpha_gt, trimap)
+        sad = sad_metric.update(alpha_pred, alpha_gt, trimap)
+        grad = grad_metric.update(alpha_pred, alpha_gt, trimap)
+        conn = conn_metric.update(alpha_pred, alpha_gt, trimap)
+
+        if sad > 1000:
+            print(data['img_name'][0])
+
+        if save_results:
+            alpha_pred_one = alpha_pred
+            alpha_pred_one[trimap == 255] = 255
+            alpha_pred_one[trimap == 0] = 0
+
+            save_name = data['img_name'][0]
+            name, ext = os.path.splitext(save_name)
+            if save_name == img_name:
+                save_name = name + '_' + str(i) + ext
+                i += 1
+            else:
+                img_name = save_name
+                save_name = name + '_' + str(0) + ext
+                i = 1
+            save_alpha_pred(alpha_pred_one, os.path.join(save_dir, save_name))
+
+        batch_cost_averager.record(
+            time.time() - batch_start, num_samples=len(alpha_gt))
+        batch_cost = batch_cost_averager.get_average()
+        reader_cost = reader_cost_averager.get_average()
+
+        if print_detail:
+            progbar_val.update(iter + 1,
+                               [('SAD', sad), ('MSE', mse), ('Grad', grad),
+                                ('Conn', conn), ('batch_cost', batch_cost),
+                                ('reader cost', reader_cost)])
+
+        reader_cost_averager.reset()
+        batch_cost_averager.reset()
+        batch_start = time.time()
+
+    mse = mse_metric.evaluate()
+    sad = sad_metric.evaluate()
+    grad = grad_metric.evaluate()
+    conn = conn_metric.evaluate()
+
+    logger.info('[EVAL] SAD: {:.4f}, MSE: {:.4f}, Grad: {:.4f}, Conn: {:.4f}'.
+                format(sad, mse, grad, conn))
+    logger.info('{}'.format(ignore_cnt))
+
+    return sad, mse, grad, conn
diff --git a/ppmatting/datasets/__init__.py b/ppmatting/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55febcaefed2e14676cbb0864f8d4cc4c1ef7459
--- /dev/null
+++ b/ppmatting/datasets/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .matting_dataset import MattingDataset
+from .composition_1k import Composition1K
+from .distinctions_646 import Distinctions646
diff --git a/ppmatting/datasets/composition_1k.py b/ppmatting/datasets/composition_1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..854b29bed6d91f20616060c3cee50fc21dc5b8f2
--- /dev/null
+++ b/ppmatting/datasets/composition_1k.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import random
+import paddle
+from paddleseg.cvlibs import manager
+
+import ppmatting.transforms as T
+from ppmatting.datasets.matting_dataset import MattingDataset
+
+
+@manager.DATASETS.add_component
+class Composition1K(MattingDataset):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
diff --git a/ppmatting/datasets/distinctions_646.py b/ppmatting/datasets/distinctions_646.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20b08f2e6b2583ef03bfdc2c30e84fcefd02607
--- /dev/null
+++ b/ppmatting/datasets/distinctions_646.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import random
+import paddle
+from paddleseg.cvlibs import manager
+
+import ppmatting.transforms as T
+from ppmatting.datasets.matting_dataset import MattingDataset
+
+
+@manager.DATASETS.add_component
+class Distinctions646(MattingDataset):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
diff --git a/ppmatting/datasets/matting_dataset.py b/ppmatting/datasets/matting_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d782d6c35acbb583f8fbdc61685e222ff0437996
--- /dev/null
+++ b/ppmatting/datasets/matting_dataset.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import random
+import paddle
+from paddleseg.cvlibs import manager
+
+import ppmatting.transforms as T
+
+
+@manager.DATASETS.add_component
+class MattingDataset(paddle.io.Dataset):
+    """
+    Pass in a dataset that conforms to the following format.
+    matting_dataset/
+    |--bg/
+    |
+    |--train/
+    |  |--fg/
+    |  |--alpha/
+    |
+    |--val/
+    |  |--fg/
+    |  |--alpha/
+    |  |--trimap/ (if existing)
+    |
+    |--train.txt
+    |
+    |--val.txt
+    See README.md for more information about the dataset.
+
+    Args:
+        dataset_root (str): The root path of the dataset.
+        transforms (list): Transforms for the image.
+        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'trainval'). Default: 'train'.
+        train_file (str|list, optional): The file list used for training. Each line should be `foreground_image.png background_image.png`
+            or `foreground_image.png`. It should be provided if mode is 'train'. Default: None.
+        val_file (str|list, optional): The file list used for evaluation. Each line should be `foreground_image.png background_image.png`,
+            `foreground_image.png` or `foreground_image.png background_image.png trimap_image.png`.
+            It should be provided if mode is 'val'. Default: None.
+        get_trimap (bool, optional): Whether to generate a trimap. Default: True.
+        separator (str, optional): The separator of train_file or val_file. If a file name contains ' ', '|' may be a better choice. Default: ' '.
+        key_del (tuple|list, optional): Keys that are not needed and will be deleted to accelerate the data reader. Default: None.
+        if_rssn (bool, optional): Whether to use RSSN when compositing images, including denoising and blurring. Default: False.
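+
+    For illustration, with the default ' ' separator a `train.txt` line might look
+    like `train/fg/0001.png bg/0001.jpg`, and a `val.txt` line with a trimap might
+    look like `val/fg/0001.png bg/0001.jpg val/trimap/0001.png` (hypothetical file
+    names, following the directory layout above).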
+ """ + + def __init__(self, + dataset_root, + transforms, + mode='train', + train_file=None, + val_file=None, + get_trimap=True, + separator=' ', + key_del=None, + if_rssn=False): + super().__init__() + self.dataset_root = dataset_root + self.transforms = T.Compose(transforms) + self.mode = mode + self.get_trimap = get_trimap + self.separator = separator + self.key_del = key_del + self.if_rssn = if_rssn + + # check file + if mode == 'train' or mode == 'trainval': + if train_file is None: + raise ValueError( + "When `mode` is 'train' or 'trainval', `train_file must be provided!" + ) + if isinstance(train_file, str): + train_file = [train_file] + file_list = train_file + + if mode == 'val' or mode == 'trainval': + if val_file is None: + raise ValueError( + "When `mode` is 'val' or 'trainval', `val_file must be provided!" + ) + if isinstance(val_file, str): + val_file = [val_file] + file_list = val_file + + if mode == 'trainval': + file_list = train_file + val_file + + # read file + self.fg_bg_list = [] + for file in file_list: + file = os.path.join(dataset_root, file) + with open(file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.strip() + self.fg_bg_list.append(line) + if mode != 'val': + random.shuffle(self.fg_bg_list) + + def __getitem__(self, idx): + data = {} + fg_bg_file = self.fg_bg_list[idx] + fg_bg_file = fg_bg_file.split(self.separator) + data['img_name'] = fg_bg_file[0] # using in save prediction results + fg_file = os.path.join(self.dataset_root, fg_bg_file[0]) + alpha_file = fg_file.replace('/fg', '/alpha') + fg = cv2.imread(fg_file) + alpha = cv2.imread(alpha_file, 0) + data['alpha'] = alpha + data['gt_fields'] = [] + + # line is: fg [bg] [trimap] + if len(fg_bg_file) >= 2: + bg_file = os.path.join(self.dataset_root, fg_bg_file[1]) + bg = cv2.imread(bg_file) + data['img'], data['fg'], data['bg'] = self.composite(fg, alpha, bg) + if self.mode in ['train', 'trainval']: + data['gt_fields'].append('fg') + data['gt_fields'].append('bg') + data['gt_fields'].append('alpha') + if len(fg_bg_file) == 3 and self.get_trimap: + if self.mode == 'val': + trimap_path = os.path.join(self.dataset_root, fg_bg_file[2]) + if os.path.exists(trimap_path): + data['trimap'] = trimap_path + data['gt_fields'].append('trimap') + data['ori_trimap'] = cv2.imread(trimap_path, 0) + else: + raise FileNotFoundError( + 'trimap is not Found: {}'.format(fg_bg_file[2])) + else: + data['img'] = fg + if self.mode in ['train', 'trainval']: + data['fg'] = fg.copy() + data['bg'] = fg.copy() + data['gt_fields'].append('fg') + data['gt_fields'].append('bg') + data['gt_fields'].append('alpha') + + data['trans_info'] = [] # Record shape change information + + # Generate trimap from alpha if no trimap file provided + if self.get_trimap: + if 'trimap' not in data: + data['trimap'] = self.gen_trimap( + data['alpha'], mode=self.mode).astype('float32') + data['gt_fields'].append('trimap') + if self.mode == 'val': + data['ori_trimap'] = data['trimap'].copy() + + # Delete key which is not need + if self.key_del is not None: + for key in self.key_del: + if key in data.keys(): + data.pop(key) + if key in data['gt_fields']: + data['gt_fields'].remove(key) + data = self.transforms(data) + + # When evaluation, gt should not be transforms. 
+        if self.mode == 'val':
+            data['gt_fields'].append('alpha')
+
+        data['img'] = data['img'].astype('float32')
+        for key in data.get('gt_fields', []):
+            data[key] = data[key].astype('float32')
+
+        if 'trimap' in data:
+            data['trimap'] = data['trimap'][np.newaxis, :, :]
+        if 'ori_trimap' in data:
+            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]
+
+        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.
+
+        return data
+
+    def __len__(self):
+        return len(self.fg_bg_list)
+
+    def composite(self, fg, alpha, ori_bg):
+        if self.if_rssn:
+            if np.random.rand() < 0.5:
+                fg = cv2.fastNlMeansDenoisingColored(fg, None, 3, 3, 7, 21)
+                ori_bg = cv2.fastNlMeansDenoisingColored(ori_bg, None, 3, 3, 7,
+                                                         21)
+            if np.random.rand() < 0.5:
+                radius = np.random.choice([19, 29, 39, 49, 59])
+                ori_bg = cv2.GaussianBlur(ori_bg, (radius, radius), 0, 0)
+        fg_h, fg_w = fg.shape[:2]
+        ori_bg_h, ori_bg_w = ori_bg.shape[:2]
+
+        wratio = fg_w / ori_bg_w
+        hratio = fg_h / ori_bg_h
+        ratio = wratio if wratio > hratio else hratio
+
+        # Resize ori_bg if it is smaller than fg.
+        if ratio > 1:
+            resize_h = math.ceil(ori_bg_h * ratio)
+            resize_w = math.ceil(ori_bg_w * ratio)
+            bg = cv2.resize(
+                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
+        else:
+            bg = ori_bg
+
+        bg = bg[0:fg_h, 0:fg_w, :]
+        alpha = alpha / 255
+        alpha = np.expand_dims(alpha, axis=2)
+        image = alpha * fg + (1 - alpha) * bg
+        image = image.astype(np.uint8)
+        return image, fg, bg
+
+    @staticmethod
+    def gen_trimap(alpha, mode='train', eval_kernel=7):
+        if mode == 'train':
+            k_size = random.choice(range(2, 5))
+            iterations = np.random.randint(5, 15)
+            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+                                               (k_size, k_size))
+            dilated = cv2.dilate(alpha, kernel, iterations=iterations)
+            eroded = cv2.erode(alpha, kernel, iterations=iterations)
+            trimap = np.zeros(alpha.shape)
+            trimap.fill(128)
+            trimap[eroded > 254.5] = 255
+            trimap[dilated < 0.5] = 0
+        else:
+            k_size = eval_kernel
+            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+                                               (k_size, k_size))
+            dilated = cv2.dilate(alpha, kernel)
+            trimap = np.zeros(alpha.shape)
+            trimap.fill(128)
+            trimap[alpha >= 250] = 255
+            trimap[dilated <= 5] = 0
+
+        return trimap
diff --git a/ppmatting/metrics/__init__.py b/ppmatting/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..836f0a973bf4331d36982252d47f7279e7c24752
--- /dev/null
+++ b/ppmatting/metrics/__init__.py
@@ -0,0 +1,3 @@
+from .metric import MSE, SAD, Grad, Conn
+
+metrics_class_dict = {'sad': SAD, 'mse': MSE, 'grad': Grad, 'conn': Conn}
diff --git a/ppmatting/metrics/metric.py b/ppmatting/metrics/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..2784dcf20fcffeadc326ad00d9b6a74d07ad58cf
--- /dev/null
+++ b/ppmatting/metrics/metric.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Grad and Conn is refer to https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/evaluate.py +# Output of `Grad` is sightly different from the MATLAB version provided by Adobe (less than 0.1%) +# Output of `Conn` is smaller than the MATLAB version (~5%, maybe MATLAB has a different algorithm) +# So do not report results calculated by these functions in your paper. +# Evaluate your inference with the MATLAB file `DIM_evaluation_code/evaluate.m`. + +import cv2 +import numpy as np +from scipy.ndimage import convolve +from scipy.special import gamma +from skimage.measure import label + + +class MSE: + """ + Only calculate the unknown region if trimap provided. + """ + + def __init__(self): + self.mse_diffs = 0 + self.count = 0 + + def update(self, pred, gt, trimap=None): + """ + update metric. + Args: + pred (np.ndarray): The value range is [0., 255.]. + gt (np.ndarray): The value range is [0, 255]. + trimap (np.ndarray, optional) The value is in {0, 128, 255}. Default: None. + """ + if trimap is None: + trimap = np.ones_like(gt) * 128 + if not (pred.shape == gt.shape == trimap.shape): + raise ValueError( + 'The shape of `pred`, `gt` and `trimap` should be equal. ' + 'but they are {}, {} and {}'.format(pred.shape, gt.shape, + trimap.shape)) + pred[trimap == 0] = 0 + pred[trimap == 255] = 255 + + mask = trimap == 128 + pixels = float(mask.sum()) + pred = pred / 255. + gt = gt / 255. + diff = (pred - gt) * mask + mse_diff = (diff**2).sum() / pixels if pixels > 0 else 0 + + self.mse_diffs += mse_diff + self.count += 1 + + return mse_diff + + def evaluate(self): + mse = self.mse_diffs / self.count if self.count > 0 else 0 + return mse + + +class SAD: + """ + Only calculate the unknown region if trimap provided. + """ + + def __init__(self): + self.sad_diffs = 0 + self.count = 0 + + def update(self, pred, gt, trimap=None): + """ + update metric. + Args: + pred (np.ndarray): The value range is [0., 255.]. + gt (np.ndarray): The value range is [0., 255.]. + trimap (np.ndarray, optional)L The value is in {0, 128, 255}. Default: None. + """ + if trimap is None: + trimap = np.ones_like(gt) * 128 + if not (pred.shape == gt.shape == trimap.shape): + raise ValueError( + 'The shape of `pred`, `gt` and `trimap` should be equal. ' + 'but they are {}, {} and {}'.format(pred.shape, gt.shape, + trimap.shape)) + pred[trimap == 0] = 0 + pred[trimap == 255] = 255 + + mask = trimap == 128 + pred = pred / 255. + gt = gt / 255. + diff = (pred - gt) * mask + sad_diff = (np.abs(diff)).sum() + + sad_diff /= 1000 + self.sad_diffs += sad_diff + self.count += 1 + + return sad_diff + + def evaluate(self): + sad = self.sad_diffs / self.count if self.count > 0 else 0 + return sad + + +class Grad: + """ + Only calculate the unknown region if trimap provided. 
+    Refer to: https://github.com/open-mmlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
+    """
+
+    def __init__(self):
+        self.grad_diffs = 0
+        self.count = 0
+
+    def gaussian(self, x, sigma):
+        return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
+
+    def dgaussian(self, x, sigma):
+        return -x * self.gaussian(x, sigma) / sigma**2
+
+    def gauss_filter(self, sigma, epsilon=1e-2):
+        half_size = np.ceil(
+            sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
+        size = int(2 * half_size + 1)
+
+        # create filter in x axis
+        filter_x = np.zeros((size, size))
+        for i in range(size):
+            for j in range(size):
+                filter_x[i, j] = self.gaussian(
+                    i - half_size, sigma) * self.dgaussian(j - half_size, sigma)
+
+        # normalize filter
+        norm = np.sqrt((filter_x**2).sum())
+        filter_x = filter_x / norm
+        filter_y = np.transpose(filter_x)
+
+        return filter_x, filter_y
+
+    def gauss_gradient(self, img, sigma):
+        filter_x, filter_y = self.gauss_filter(sigma)
+        img_filtered_x = cv2.filter2D(
+            img, -1, filter_x, borderType=cv2.BORDER_REPLICATE)
+        img_filtered_y = cv2.filter2D(
+            img, -1, filter_y, borderType=cv2.BORDER_REPLICATE)
+        return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
+
+    def update(self, pred, gt, trimap=None, sigma=1.4):
+        """
+        Update the metric.
+
+        Args:
+            pred (np.ndarray): The value range is [0., 255.].
+            gt (np.ndarray): The value range is [0, 255].
+            trimap (np.ndarray, optional): The values are in {0, 128, 255}. Default: None.
+            sigma (float, optional): Standard deviation of the Gaussian kernel. Default: 1.4.
+        """
+        if trimap is None:
+            trimap = np.ones_like(gt) * 128
+        if not (pred.shape == gt.shape == trimap.shape):
+            raise ValueError(
+                'The shape of `pred`, `gt` and `trimap` should be equal, '
+                'but they are {}, {} and {}'.format(pred.shape, gt.shape,
+                                                    trimap.shape))
+        pred[trimap == 0] = 0
+        pred[trimap == 255] = 255
+
+        gt = gt.squeeze()
+        pred = pred.squeeze()
+        gt = gt.astype(np.float64)
+        pred = pred.astype(np.float64)
+        gt_normed = np.zeros_like(gt)
+        pred_normed = np.zeros_like(pred)
+        cv2.normalize(gt, gt_normed, 1., 0., cv2.NORM_MINMAX)
+        cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX)
+
+        gt_grad = self.gauss_gradient(gt_normed, sigma).astype(np.float32)
+        pred_grad = self.gauss_gradient(pred_normed, sigma).astype(np.float32)
+
+        grad_diff = ((gt_grad - pred_grad)**2 * (trimap == 128)).sum()
+
+        grad_diff /= 1000
+        self.grad_diffs += grad_diff
+        self.count += 1
+
+        return grad_diff
+
+    def evaluate(self):
+        grad = self.grad_diffs / self.count if self.count > 0 else 0
+        return grad
+
+
+class Conn:
+    """
+    Only the unknown region is calculated if a trimap is provided.
+    Refer to: https://github.com/open-mmlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
+    """
+
+    def __init__(self):
+        self.conn_diffs = 0
+        self.count = 0
+
+    def update(self, pred, gt, trimap=None, step=0.1):
+        """
+        Update the metric.
+
+        Args:
+            pred (np.ndarray): The value range is [0., 255.].
+            gt (np.ndarray): The value range is [0, 255].
+            trimap (np.ndarray, optional): The values are in {0, 128, 255}. Default: None.
+            step (float, optional): Step of threshold when computing intersection between
+                `gt` and `pred`. Default: 0.1.
+        """
+        if trimap is None:
+            trimap = np.ones_like(gt) * 128
+        if not (pred.shape == gt.shape == trimap.shape):
+            raise ValueError(
+                'The shape of `pred`, `gt` and `trimap` should be equal, '
+                'but they are {}, {} and {}'.format(pred.shape, gt.shape,
+                                                    trimap.shape))
+        pred[trimap == 0] = 0
+        pred[trimap == 255] = 255
+
+        gt = gt.squeeze()
+        pred = pred.squeeze()
+        gt = gt.astype(np.float32) / 255
+        pred = pred.astype(np.float32) / 255
+
+        thresh_steps = np.arange(0, 1 + step, step)
+        round_down_map = -np.ones_like(gt)
+        for i in range(1, len(thresh_steps)):
+            gt_thresh = gt >= thresh_steps[i]
+            pred_thresh = pred >= thresh_steps[i]
+            intersection = (gt_thresh & pred_thresh).astype(np.uint8)
+
+            # connected components
+            _, output, stats, _ = cv2.connectedComponentsWithStats(
+                intersection, connectivity=4)
+            # start from 1 in dim 0 to exclude background
+            size = stats[1:, -1]
+
+            # largest connected component of the intersection
+            omega = np.zeros_like(gt)
+            if len(size) != 0:
+                max_id = np.argmax(size)
+                # plus one to include background
+                omega[output == max_id + 1] = 1
+
+            mask = (round_down_map == -1) & (omega == 0)
+            round_down_map[mask] = thresh_steps[i - 1]
+        round_down_map[round_down_map == -1] = 1
+
+        gt_diff = gt - round_down_map
+        pred_diff = pred - round_down_map
+        # only calculate difference larger than or equal to 0.15
+        gt_phi = 1 - gt_diff * (gt_diff >= 0.15)
+        pred_phi = 1 - pred_diff * (pred_diff >= 0.15)
+
+        conn_diff = np.sum(np.abs(gt_phi - pred_phi) * (trimap == 128))
+
+        conn_diff /= 1000
+        self.conn_diffs += conn_diff
+        self.count += 1
+
+        return conn_diff
+
+    def evaluate(self):
+        conn = self.conn_diffs / self.count if self.count > 0 else 0
+        return conn
diff --git a/ppmatting/ml/__init__.py b/ppmatting/ml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..612dff101f358f74db3eca601f0b9573ca6d93cb
--- /dev/null
+++ b/ppmatting/ml/__init__.py
@@ -0,0 +1 @@
+from .methods import CloseFormMatting, KNNMatting, LearningBasedMatting, FastMatting, RandomWalksMatting
diff --git a/ppmatting/ml/methods.py b/ppmatting/ml/methods.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d5fea2475552c14d29fe44fd08cf436e55bdbd
--- /dev/null
+++ b/ppmatting/ml/methods.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
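+# Illustrative usage of the metrics above (a sketch, not part of the original
+# files). Each metric accumulates statistics over `update` calls and reports
+# the running mean via `evaluate`; `pred` and `gt` are expected in [0, 255],
+# and an optional trimap restricts scoring to the unknown (128) region.
+def _demo_metrics(pred, gt, trimap=None):
+    from ppmatting.metrics import metrics_class_dict
+    results = {}
+    for name, metric_cls in metrics_class_dict.items():
+        metric = metric_cls()
+        # pass a copy: update() overwrites pred inside known trimap regions
+        metric.update(pred.copy(), gt, trimap)
+        results[name] = metric.evaluate()
+    return results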
+ +import pymatting +from paddleseg.cvlibs import manager + + +class BaseMLMatting(object): + def __init__(self, alpha_estimator, **kargs): + self.alpha_estimator = alpha_estimator + self.kargs = kargs + + def __call__(self, image, trimap): + image = self.__to_float64(image) + trimap = self.__to_float64(trimap) + alpha_matte = self.alpha_estimator(image, trimap, **self.kargs) + return alpha_matte + + def __to_float64(self, x): + x_dtype = x.dtype + assert x_dtype in ["float32", "float64"] + x = x.astype("float64") + return x + + +@manager.MODELS.add_component +class CloseFormMatting(BaseMLMatting): + def __init__(self, **kargs): + cf_alpha_estimator = pymatting.estimate_alpha_cf + super().__init__(cf_alpha_estimator, **kargs) + + +@manager.MODELS.add_component +class KNNMatting(BaseMLMatting): + def __init__(self, **kargs): + knn_alpha_estimator = pymatting.estimate_alpha_knn + super().__init__(knn_alpha_estimator, **kargs) + + +@manager.MODELS.add_component +class LearningBasedMatting(BaseMLMatting): + def __init__(self, **kargs): + lbdm_alpha_estimator = pymatting.estimate_alpha_lbdm + super().__init__(lbdm_alpha_estimator, **kargs) + + +@manager.MODELS.add_component +class FastMatting(BaseMLMatting): + def __init__(self, **kargs): + lkm_alpha_estimator = pymatting.estimate_alpha_lkm + super().__init__(lkm_alpha_estimator, **kargs) + + +@manager.MODELS.add_component +class RandomWalksMatting(BaseMLMatting): + def __init__(self, **kargs): + rw_alpha_estimator = pymatting.estimate_alpha_rw + super().__init__(rw_alpha_estimator, **kargs) + + +if __name__ == "__main__": + from pymatting.util.util import load_image, save_image, stack_images + from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml + import cv2 + + root = "/mnt/liuyi22/PaddlePaddle/PaddleSeg/Matting/data/examples/" + image_path = root + "lemur.png" + trimap_path = root + "lemur_trimap.png" + cutout_path = root + "lemur_cutout.png" + image = cv2.cvtColor( + cv2.imread(image_path).astype("float64"), cv2.COLOR_BGR2RGB) / 255.0 + + cv2.imwrite("image.png", (image * 255).astype('uint8')) + trimap = load_image(trimap_path, "GRAY") + print(image.shape, trimap.shape) + print(image.dtype, trimap.dtype) + cf = CloseFormMatting() + alpha = cf(image, trimap) + + # alpha = pymatting.estimate_alpha_lkm(image, trimap) + + foreground = estimate_foreground_ml(image, alpha) + + cutout = stack_images(foreground, alpha) + + save_image(cutout_path, cutout) diff --git a/ppmatting/models/__init__.py b/ppmatting/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d446649bc75b44f5ff3f9183e22f057f128b5fa2 --- /dev/null +++ b/ppmatting/models/__init__.py @@ -0,0 +1,7 @@ +from .backbone import * +from .losses import * +from .modnet import MODNet +from .human_matting import HumanMatting +from .dim import DIM +from .ppmatting import PPMatting +from .gca import GCABaseline, GCA diff --git a/ppmatting/models/backbone/__init__.py b/ppmatting/models/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08005b31477e57488132cd2f5d3692c6e687b4f --- /dev/null +++ b/ppmatting/models/backbone/__init__.py @@ -0,0 +1,5 @@ +from .mobilenet_v2 import * +from .hrnet import * +from .resnet_vd import * +from .vgg import * +from .gca_enc import * \ No newline at end of file diff --git a/ppmatting/models/backbone/gca_enc.py b/ppmatting/models/backbone/gca_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..2afeb5df8c398d89ac1d4fe8e411571afebec5b6 --- 
/dev/null +++ b/ppmatting/models/backbone/gca_enc.py @@ -0,0 +1,395 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting +# and https://github.com/open-mmlab/mmediting + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddleseg.cvlibs import manager, param_init +from paddleseg.utils import utils + +from ppmatting.models.layers import GuidedCxtAtten + + +class ResNet_D(nn.Layer): + def __init__(self, + input_channels, + layers, + late_downsample=False, + pretrained=None): + + super().__init__() + + self.pretrained = pretrained + + self._norm_layer = nn.BatchNorm + self.inplanes = 64 + self.late_downsample = late_downsample + self.midplanes = 64 if late_downsample else 32 + self.start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1] + self.conv1 = nn.utils.spectral_norm( + nn.Conv2D( + input_channels, + 32, + kernel_size=3, + stride=self.start_stride[0], + padding=1, + bias_attr=False)) + self.conv2 = nn.utils.spectral_norm( + nn.Conv2D( + 32, + self.midplanes, + kernel_size=3, + stride=self.start_stride[1], + padding=1, + bias_attr=False)) + self.conv3 = nn.utils.spectral_norm( + nn.Conv2D( + self.midplanes, + self.inplanes, + kernel_size=3, + stride=self.start_stride[2], + padding=1, + bias_attr=False)) + self.bn1 = self._norm_layer(32) + self.bn2 = self._norm_layer(self.midplanes) + self.bn3 = self._norm_layer(self.inplanes) + self.activation = nn.ReLU() + self.layer1 = self._make_layer( + BasicBlock, 64, layers[0], stride=self.start_stride[3]) + self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) + self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) + self.layer_bottleneck = self._make_layer( + BasicBlock, 512, layers[3], stride=2) + + self.init_weight() + + def _make_layer(self, block, planes, block_num, stride=1): + if block_num == 0: + return nn.Sequential(nn.Identity()) + norm_layer = self._norm_layer + downsample = None + if stride != 1: + downsample = nn.Sequential( + nn.AvgPool2D(2, stride), + nn.utils.spectral_norm( + conv1x1(self.inplanes, planes * block.expansion)), + norm_layer(planes * block.expansion), ) + elif self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.utils.spectral_norm( + conv1x1(self.inplanes, planes * block.expansion, stride)), + norm_layer(planes * block.expansion), ) + + layers = [block(self.inplanes, planes, stride, downsample, norm_layer)] + self.inplanes = planes * block.expansion + for _ in range(1, block_num): + layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.bn2(x) + x1 = self.activation(x) # N x 32 x 256 x 256 + x = self.conv3(x1) + x = self.bn3(x) + x2 = self.activation(x) # N x 64 x 128 x 128 + + x3 = self.layer1(x2) # N x 64 x 128 x 128 + x4 = 
self.layer2(x3) # N x 128 x 64 x 64 + x5 = self.layer3(x4) # N x 256 x 32 x 32 + x = self.layer_bottleneck(x5) # N x 512 x 16 x 16 + + return x, (x1, x2, x3, x4, x5) + + def init_weight(self): + + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2D): + + if hasattr(layer, "weight_orig"): + param = layer.weight_orig + else: + param = layer.weight + param_init.xavier_uniform(param) + + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) + + elif isinstance(layer, BasicBlock): + param_init.constant_init(layer.bn2.weight, value=0.0) + + if self.pretrained is not None: + utils.load_pretrained_model(self, self.pretrained) + + +@manager.MODELS.add_component +class ResShortCut_D(ResNet_D): + def __init__(self, + input_channels, + layers, + late_downsample=False, + pretrained=None): + super().__init__( + input_channels, + layers, + late_downsample=late_downsample, + pretrained=pretrained) + + self.shortcut_inplane = [input_channels, self.midplanes, 64, 128, 256] + self.shortcut_plane = [32, self.midplanes, 64, 128, 256] + + self.shortcut = nn.LayerList() + for stage, inplane in enumerate(self.shortcut_inplane): + self.shortcut.append( + self._make_shortcut(inplane, self.shortcut_plane[stage])) + + def _make_shortcut(self, inplane, planes): + return nn.Sequential( + nn.utils.spectral_norm( + nn.Conv2D( + inplane, planes, kernel_size=3, padding=1, + bias_attr=False)), + nn.ReLU(), + self._norm_layer(planes), + nn.utils.spectral_norm( + nn.Conv2D( + planes, planes, kernel_size=3, padding=1, bias_attr=False)), + nn.ReLU(), + self._norm_layer(planes)) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.activation(out) + out = self.conv2(out) + out = self.bn2(out) + x1 = self.activation(out) # N x 32 x 256 x 256 + out = self.conv3(x1) + out = self.bn3(out) + out = self.activation(out) + + x2 = self.layer1(out) # N x 64 x 128 x 128 + x3 = self.layer2(x2) # N x 128 x 64 x 64 + x4 = self.layer3(x3) # N x 256 x 32 x 32 + out = self.layer_bottleneck(x4) # N x 512 x 16 x 16 + + fea1 = self.shortcut[0](x) # input image and trimap + fea2 = self.shortcut[1](x1) + fea3 = self.shortcut[2](x2) + fea4 = self.shortcut[3](x3) + fea5 = self.shortcut[4](x4) + + return out, { + 'shortcut': (fea1, fea2, fea3, fea4, fea5), + 'image': x[:, :3, ...] 
+ } + + +@manager.MODELS.add_component +class ResGuidedCxtAtten(ResNet_D): + def __init__(self, + input_channels, + layers, + late_downsample=False, + pretrained=None): + super().__init__( + input_channels, + layers, + late_downsample=late_downsample, + pretrained=pretrained) + self.input_channels = input_channels + self.shortcut_inplane = [input_channels, self.midplanes, 64, 128, 256] + self.shortcut_plane = [32, self.midplanes, 64, 128, 256] + + self.shortcut = nn.LayerList() + for stage, inplane in enumerate(self.shortcut_inplane): + self.shortcut.append( + self._make_shortcut(inplane, self.shortcut_plane[stage])) + + self.guidance_head = nn.Sequential( + nn.Pad2D( + 1, mode="reflect"), + nn.utils.spectral_norm( + nn.Conv2D( + 3, 16, kernel_size=3, padding=0, stride=2, + bias_attr=False)), + nn.ReLU(), + self._norm_layer(16), + nn.Pad2D( + 1, mode="reflect"), + nn.utils.spectral_norm( + nn.Conv2D( + 16, 32, kernel_size=3, padding=0, stride=2, + bias_attr=False)), + nn.ReLU(), + self._norm_layer(32), + nn.Pad2D( + 1, mode="reflect"), + nn.utils.spectral_norm( + nn.Conv2D( + 32, + 128, + kernel_size=3, + padding=0, + stride=2, + bias_attr=False)), + nn.ReLU(), + self._norm_layer(128)) + + self.gca = GuidedCxtAtten(128, 128) + + self.init_weight() + + def init_weight(self): + + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2D): + initializer = nn.initializer.XavierUniform() + if hasattr(layer, "weight_orig"): + param = layer.weight_orig + else: + param = layer.weight + initializer(param, param.block) + + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) + + elif isinstance(layer, BasicBlock): + param_init.constant_init(layer.bn2.weight, value=0.0) + + if self.pretrained is not None: + utils.load_pretrained_model(self, self.pretrained) + + def _make_shortcut(self, inplane, planes): + return nn.Sequential( + nn.utils.spectral_norm( + nn.Conv2D( + inplane, planes, kernel_size=3, padding=1, + bias_attr=False)), + nn.ReLU(), + self._norm_layer(planes), + nn.utils.spectral_norm( + nn.Conv2D( + planes, planes, kernel_size=3, padding=1, bias_attr=False)), + nn.ReLU(), + self._norm_layer(planes)) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = self.activation(out) + out = self.conv2(out) + out = self.bn2(out) + x1 = self.activation(out) # N x 32 x 256 x 256 + out = self.conv3(x1) + out = self.bn3(out) + out = self.activation(out) + + im_fea = self.guidance_head( + x[:, :3, ...]) # downsample origin image and extract features + if self.input_channels == 6: + unknown = F.interpolate( + x[:, 4:5, ...], scale_factor=1 / 8, mode='nearest') + else: + unknown = x[:, 3:, ...].equal(paddle.to_tensor([1.])) + unknown = paddle.cast(unknown, dtype='float32') + unknown = F.interpolate(unknown, scale_factor=1 / 8, mode='nearest') + + x2 = self.layer1(out) # N x 64 x 128 x 128 + x3 = self.layer2(x2) # N x 128 x 64 x 64 + x3 = self.gca(im_fea, x3, unknown) # contextual attention + x4 = self.layer3(x3) # N x 256 x 32 x 32 + out = self.layer_bottleneck(x4) # N x 512 x 16 x 16 + + fea1 = self.shortcut[0](x) # input image and trimap + fea2 = self.shortcut[1](x1) + fea3 = self.shortcut[2](x2) + fea4 = self.shortcut[3](x3) + fea5 = self.shortcut[4](x4) + + return out, { + 'shortcut': (fea1, fea2, fea3, fea4, fea5), + 'image_fea': im_fea, + 'unknown': unknown, + } + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + 
                 stride=1,
+                 downsample=None,
+                 norm_layer=None):
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = nn.utils.spectral_norm(conv3x3(inplanes, planes, stride))
+        self.bn1 = norm_layer(planes)
+        self.activation = nn.ReLU()
+        self.conv2 = nn.utils.spectral_norm(conv3x3(planes, planes))
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.activation(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.activation(out)
+
+        return out
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2D(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias_attr=False,
+        dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2D(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
diff --git a/ppmatting/models/backbone/hrnet.py b/ppmatting/models/backbone/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..96e23a77e656142a97c573feb501f983aecebbef
--- /dev/null
+++ b/ppmatting/models/backbone/hrnet.py
@@ -0,0 +1,835 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager, param_init
+from paddleseg.models import layers
+from paddleseg.utils import utils
+
+__all__ = [
+    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
+    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
+]
+
+
+class HRNet(nn.Layer):
+    """
+    The HRNet implementation based on PaddlePaddle.
+
+    The original article refers to
+    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
+    (https://arxiv.org/pdf/1908.07919.pdf).
+
+    Args:
+        input_channels (int, optional): The channels of the input image. Default: 3.
+        pretrained (str, optional): The path of pretrained model.
+        stage1_num_modules (int, optional): Number of modules for stage1. Default: 1.
+        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default: (4, ).
+        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default: (64, ).
+        stage2_num_modules (int, optional): Number of modules for stage2. Default: 1.
+        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default: (4, 4).
+        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default: (18, 36).
+        stage3_num_modules (int, optional): Number of modules for stage3. Default: 4.
+        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default: (4, 4, 4).
+        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default: (18, 36, 72).
+        stage4_num_modules (int, optional): Number of modules for stage4. Default: 3.
+        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default: (4, 4, 4, 4).
+        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default: (18, 36, 72, 144).
+        has_se (bool, optional): Whether to use the Squeeze-and-Excitation module. Default: False.
+        align_corners (bool, optional): An argument of F.interpolate. It should be set to False
+            when the feature size is even, e.g. 1024x512; otherwise True, e.g. 769x769. Default: False.
+    """
+
+    def __init__(self,
+                 input_channels=3,
+                 pretrained=None,
+                 stage1_num_modules=1,
+                 stage1_num_blocks=(4, ),
+                 stage1_num_channels=(64, ),
+                 stage2_num_modules=1,
+                 stage2_num_blocks=(4, 4),
+                 stage2_num_channels=(18, 36),
+                 stage3_num_modules=4,
+                 stage3_num_blocks=(4, 4, 4),
+                 stage3_num_channels=(18, 36, 72),
+                 stage4_num_modules=3,
+                 stage4_num_blocks=(4, 4, 4, 4),
+                 stage4_num_channels=(18, 36, 72, 144),
+                 has_se=False,
+                 align_corners=False,
+                 padding_same=True):
+        super(HRNet, self).__init__()
+        self.pretrained = pretrained
+        self.stage1_num_modules = stage1_num_modules
+        self.stage1_num_blocks = stage1_num_blocks
+        self.stage1_num_channels = stage1_num_channels
+        self.stage2_num_modules = stage2_num_modules
+        self.stage2_num_blocks = stage2_num_blocks
+        self.stage2_num_channels = stage2_num_channels
+        self.stage3_num_modules = stage3_num_modules
+        self.stage3_num_blocks = stage3_num_blocks
+        self.stage3_num_channels = stage3_num_channels
+        self.stage4_num_modules = stage4_num_modules
+        self.stage4_num_blocks = stage4_num_blocks
+        self.stage4_num_channels = stage4_num_channels
+        self.has_se = has_se
+        self.align_corners = align_corners
+
+        self.feat_channels = [i for i in stage4_num_channels]
+        self.feat_channels = [64] + self.feat_channels
+
+        self.conv_layer1_1 = layers.ConvBNReLU(
+            in_channels=input_channels,
+            out_channels=64,
+            kernel_size=3,
+            stride=2,
+            padding=1 if not padding_same else 'same',
+            bias_attr=False)
+
+        self.conv_layer1_2 = layers.ConvBNReLU(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=2,
+            padding=1 if not padding_same else 'same',
+            bias_attr=False)
+
+        self.la1 = Layer1(
+            num_channels=64,
+            num_blocks=self.stage1_num_blocks[0],
+            num_filters=self.stage1_num_channels[0],
+            has_se=has_se,
+            name="layer2",
+            padding_same=padding_same)
+
+        self.tr1 = TransitionLayer(
+            in_channels=[self.stage1_num_channels[0] * 4],
+            out_channels=self.stage2_num_channels,
+            name="tr1",
+            padding_same=padding_same)
+
+        self.st2 = Stage(
+            num_channels=self.stage2_num_channels,
+            num_modules=self.stage2_num_modules,
+            num_blocks=self.stage2_num_blocks,
+            num_filters=self.stage2_num_channels,
+            has_se=self.has_se,
+            name="st2",
+            align_corners=align_corners,
+            padding_same=padding_same)
+
+        self.tr2 = TransitionLayer(
+            in_channels=self.stage2_num_channels,
+            out_channels=self.stage3_num_channels,
+            name="tr2",
+            padding_same=padding_same)
+        self.st3 = Stage(
+            num_channels=self.stage3_num_channels,
+            num_modules=self.stage3_num_modules,
+            num_blocks=self.stage3_num_blocks,
+            num_filters=self.stage3_num_channels,
+            has_se=self.has_se,
+            name="st3",
+            align_corners=align_corners,
+            padding_same=padding_same)
+
+        self.tr3 = TransitionLayer(
+            in_channels=self.stage3_num_channels,
+            out_channels=self.stage4_num_channels,
+            name="tr3",
+            padding_same=padding_same)
+        self.st4 = Stage(
num_channels=self.stage4_num_channels, + num_modules=self.stage4_num_modules, + num_blocks=self.stage4_num_blocks, + num_filters=self.stage4_num_channels, + has_se=self.has_se, + name="st4", + align_corners=align_corners, + padding_same=padding_same) + + self.init_weight() + + def forward(self, x): + feat_list = [] + conv1 = self.conv_layer1_1(x) + feat_list.append(conv1) + conv2 = self.conv_layer1_2(conv1) + + la1 = self.la1(conv2) + + tr1 = self.tr1([la1]) + st2 = self.st2(tr1) + + tr2 = self.tr2(st2) + st3 = self.st3(tr2) + + tr3 = self.tr3(st3) + st4 = self.st4(tr3) + + feat_list = feat_list + st4 + + return feat_list + + def init_weight(self): + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2D): + param_init.normal_init(layer.weight, std=0.001) + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) + if self.pretrained is not None: + utils.load_pretrained_model(self, self.pretrained) + + +class Layer1(nn.Layer): + def __init__(self, + num_channels, + num_filters, + num_blocks, + has_se=False, + name=None, + padding_same=True): + super(Layer1, self).__init__() + + self.bottleneck_block_list = [] + + for i in range(num_blocks): + bottleneck_block = self.add_sublayer( + "bb_{}_{}".format(name, i + 1), + BottleneckBlock( + num_channels=num_channels if i == 0 else num_filters * 4, + num_filters=num_filters, + has_se=has_se, + stride=1, + downsample=True if i == 0 else False, + name=name + '_' + str(i + 1), + padding_same=padding_same)) + self.bottleneck_block_list.append(bottleneck_block) + + def forward(self, x): + conv = x + for block_func in self.bottleneck_block_list: + conv = block_func(conv) + return conv + + +class TransitionLayer(nn.Layer): + def __init__(self, in_channels, out_channels, name=None, padding_same=True): + super(TransitionLayer, self).__init__() + + num_in = len(in_channels) + num_out = len(out_channels) + self.conv_bn_func_list = [] + for i in range(num_out): + residual = None + if i < num_in: + if in_channels[i] != out_channels[i]: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + layers.ConvBNReLU( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=3, + padding=1 if not padding_same else 'same', + bias_attr=False)) + else: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + layers.ConvBNReLU( + in_channels=in_channels[-1], + out_channels=out_channels[i], + kernel_size=3, + stride=2, + padding=1 if not padding_same else 'same', + bias_attr=False)) + self.conv_bn_func_list.append(residual) + + def forward(self, x): + outs = [] + for idx, conv_bn_func in enumerate(self.conv_bn_func_list): + if conv_bn_func is None: + outs.append(x[idx]) + else: + if idx < len(x): + outs.append(conv_bn_func(x[idx])) + else: + outs.append(conv_bn_func(x[-1])) + return outs + + +class Branches(nn.Layer): + def __init__(self, + num_blocks, + in_channels, + out_channels, + has_se=False, + name=None, + padding_same=True): + super(Branches, self).__init__() + + self.basic_block_list = [] + + for i in range(len(out_channels)): + self.basic_block_list.append([]) + for j in range(num_blocks[i]): + in_ch = in_channels[i] if j == 0 else out_channels[i] + basic_block_func = self.add_sublayer( + "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1), + BasicBlock( + num_channels=in_ch, + num_filters=out_channels[i], + has_se=has_se, + name=name + '_branch_layer_' + str(i + 1) + '_' + + str(j + 
1), + padding_same=padding_same)) + self.basic_block_list[i].append(basic_block_func) + + def forward(self, x): + outs = [] + for idx, input in enumerate(x): + conv = input + for basic_block_func in self.basic_block_list[idx]: + conv = basic_block_func(conv) + outs.append(conv) + return outs + + +class BottleneckBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + has_se, + stride=1, + downsample=False, + name=None, + padding_same=True): + super(BottleneckBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + + self.conv1 = layers.ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=1, + bias_attr=False) + + self.conv2 = layers.ConvBNReLU( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + stride=stride, + padding=1 if not padding_same else 'same', + bias_attr=False) + + self.conv3 = layers.ConvBN( + in_channels=num_filters, + out_channels=num_filters * 4, + kernel_size=1, + bias_attr=False) + + if self.downsample: + self.conv_down = layers.ConvBN( + in_channels=num_channels, + out_channels=num_filters * 4, + kernel_size=1, + bias_attr=False) + + if self.has_se: + self.se = SELayer( + num_channels=num_filters * 4, + num_filters=num_filters * 4, + reduction_ratio=16, + name=name + '_fc') + + self.add = layers.Add() + self.relu = layers.Activation("relu") + + def forward(self, x): + residual = x + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + if self.downsample: + residual = self.conv_down(x) + + if self.has_se: + conv3 = self.se(conv3) + + y = self.add(conv3, residual) + y = self.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride=1, + has_se=False, + downsample=False, + name=None, + padding_same=True): + super(BasicBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + + self.conv1 = layers.ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=3, + stride=stride, + padding=1 if not padding_same else 'same', + bias_attr=False) + self.conv2 = layers.ConvBN( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + padding=1 if not padding_same else 'same', + bias_attr=False) + + if self.downsample: + self.conv_down = layers.ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=1, + bias_attr=False) + + if self.has_se: + self.se = SELayer( + num_channels=num_filters, + num_filters=num_filters, + reduction_ratio=16, + name=name + '_fc') + + self.add = layers.Add() + self.relu = layers.Activation("relu") + + def forward(self, x): + residual = x + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + + if self.downsample: + residual = self.conv_down(x) + + if self.has_se: + conv2 = self.se(conv2) + + y = self.add(conv2, residual) + y = self.relu(y) + return y + + +class SELayer(nn.Layer): + def __init__(self, num_channels, num_filters, reduction_ratio, name=None): + super(SELayer, self).__init__() + + self.pool2d_gap = nn.AdaptiveAvgPool2D(1) + + self._num_channels = num_channels + + med_ch = int(num_channels / reduction_ratio) + stdv = 1.0 / math.sqrt(num_channels * 1.0) + self.squeeze = nn.Linear( + num_channels, + med_ch, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv))) + + stdv = 1.0 / math.sqrt(med_ch * 1.0) + self.excitation = nn.Linear( + med_ch, + num_filters, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv))) + + def forward(self, x): 
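+        # Squeeze-and-Excitation: global-average-pool each channel to a
+        # scalar, squeeze through an FC bottleneck (reduction_ratio), then
+        # "excite" by rescaling the input channels with sigmoid gates.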
+ pool = self.pool2d_gap(x) + pool = paddle.reshape(pool, shape=[-1, self._num_channels]) + squeeze = self.squeeze(pool) + squeeze = F.relu(squeeze) + excitation = self.excitation(squeeze) + excitation = F.sigmoid(excitation) + excitation = paddle.reshape( + excitation, shape=[-1, self._num_channels, 1, 1]) + out = x * excitation + return out + + +class Stage(nn.Layer): + def __init__(self, + num_channels, + num_modules, + num_blocks, + num_filters, + has_se=False, + multi_scale_output=True, + name=None, + align_corners=False, + padding_same=True): + super(Stage, self).__init__() + + self._num_modules = num_modules + + self.stage_func_list = [] + for i in range(num_modules): + if i == num_modules - 1 and not multi_scale_output: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_blocks=num_blocks, + num_filters=num_filters, + has_se=has_se, + multi_scale_output=False, + name=name + '_' + str(i + 1), + align_corners=align_corners, + padding_same=padding_same)) + else: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_blocks=num_blocks, + num_filters=num_filters, + has_se=has_se, + name=name + '_' + str(i + 1), + align_corners=align_corners, + padding_same=padding_same)) + + self.stage_func_list.append(stage_func) + + def forward(self, x): + out = x + for idx in range(self._num_modules): + out = self.stage_func_list[idx](out) + return out + + +class HighResolutionModule(nn.Layer): + def __init__(self, + num_channels, + num_blocks, + num_filters, + has_se=False, + multi_scale_output=True, + name=None, + align_corners=False, + padding_same=True): + super(HighResolutionModule, self).__init__() + + self.branches_func = Branches( + num_blocks=num_blocks, + in_channels=num_channels, + out_channels=num_filters, + has_se=has_se, + name=name, + padding_same=padding_same) + + self.fuse_func = FuseLayers( + in_channels=num_filters, + out_channels=num_filters, + multi_scale_output=multi_scale_output, + name=name, + align_corners=align_corners, + padding_same=padding_same) + + def forward(self, x): + out = self.branches_func(x) + out = self.fuse_func(out) + return out + + +class FuseLayers(nn.Layer): + def __init__(self, + in_channels, + out_channels, + multi_scale_output=True, + name=None, + align_corners=False, + padding_same=True): + super(FuseLayers, self).__init__() + + self._actual_ch = len(in_channels) if multi_scale_output else 1 + self._in_channels = in_channels + self.align_corners = align_corners + + self.residual_func_list = [] + for i in range(self._actual_ch): + for j in range(len(in_channels)): + if j > i: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}".format(name, i + 1, j + 1), + layers.ConvBN( + in_channels=in_channels[j], + out_channels=out_channels[i], + kernel_size=1, + bias_attr=False)) + self.residual_func_list.append(residual_func) + elif j < i: + pre_num_filters = in_channels[j] + for k in range(i - j): + if k == i - j - 1: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + layers.ConvBN( + in_channels=pre_num_filters, + out_channels=out_channels[i], + kernel_size=3, + stride=2, + padding=1 if not padding_same else 'same', + bias_attr=False)) + pre_num_filters = out_channels[i] + else: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + layers.ConvBNReLU( + in_channels=pre_num_filters, + 
out_channels=out_channels[j], + kernel_size=3, + stride=2, + padding=1 if not padding_same else 'same', + bias_attr=False)) + pre_num_filters = out_channels[j] + self.residual_func_list.append(residual_func) + + def forward(self, x): + outs = [] + residual_func_idx = 0 + for i in range(self._actual_ch): + residual = x[i] + residual_shape = paddle.shape(residual)[-2:] + for j in range(len(self._in_channels)): + if j > i: + y = self.residual_func_list[residual_func_idx](x[j]) + residual_func_idx += 1 + + y = F.interpolate( + y, + residual_shape, + mode='bilinear', + align_corners=self.align_corners) + residual = residual + y + elif j < i: + y = x[j] + for k in range(i - j): + y = self.residual_func_list[residual_func_idx](y) + residual_func_idx += 1 + + residual = residual + y + + residual = F.relu(residual) + outs.append(residual) + + return outs + + +@manager.BACKBONES.add_component +def HRNet_W18_Small_V1(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[1], + stage1_num_channels=[32], + stage2_num_modules=1, + stage2_num_blocks=[2, 2], + stage2_num_channels=[16, 32], + stage3_num_modules=1, + stage3_num_blocks=[2, 2, 2], + stage3_num_channels=[16, 32, 64], + stage4_num_modules=1, + stage4_num_blocks=[2, 2, 2, 2], + stage4_num_channels=[16, 32, 64, 128], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W18_Small_V2(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[2], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[2, 2], + stage2_num_channels=[18, 36], + stage3_num_modules=3, + stage3_num_blocks=[2, 2, 2], + stage3_num_channels=[18, 36, 72], + stage4_num_modules=2, + stage4_num_blocks=[2, 2, 2, 2], + stage4_num_channels=[18, 36, 72, 144], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W18(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[18, 36], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[18, 36, 72], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[18, 36, 72, 144], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W30(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[30, 60], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[30, 60, 120], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[30, 60, 120, 240], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W32(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[32, 64], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[32, 64, 128], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[32, 64, 128, 256], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W40(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[40, 80], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[40, 80, 160], + stage4_num_modules=3, + 
stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[40, 80, 160, 320], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W44(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[44, 88], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[44, 88, 176], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[44, 88, 176, 352], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W48(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[48, 96], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[48, 96, 192], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[48, 96, 192, 384], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W60(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[60, 120], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[60, 120, 240], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[60, 120, 240, 480], + **kwargs) + return model + + +@manager.BACKBONES.add_component +def HRNet_W64(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[64, 128], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[64, 128, 256], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[64, 128, 256, 512], + **kwargs) + return model diff --git a/ppmatting/models/backbone/mobilenet_v2.py b/ppmatting/models/backbone/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b947d78797b856f7628e19361f1e2f4261b6cd --- /dev/null +++ b/ppmatting/models/backbone/mobilenet_v2.py @@ -0,0 +1,242 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
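+# Illustrative usage of the HRNet backbones defined above (a sketch, not part
+# of the original files): every `HRNet_W*` factory returns an `HRNet` whose
+# forward yields one feature map per entry of `feat_channels`.
+def _demo_hrnet_features():
+    import paddle
+    from ppmatting.models.backbone import HRNet_W18
+    net = HRNet_W18(input_channels=3)
+    feats = net(paddle.rand([1, 3, 512, 512]))
+    # feat_channels is [64, 18, 36, 72, 144] for W18: the stem plus stage 4.
+    return [f.shape for f in feats]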
+ +import math + +import numpy as np +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D + +from paddleseg import utils +from paddleseg.cvlibs import manager + +MODEL_URLS = { + "MobileNetV2_x0_25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_25_pretrained.pdparams", + "MobileNetV2_x0_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_5_pretrained.pdparams", + "MobileNetV2_x0_75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_75_pretrained.pdparams", + "MobileNetV2": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams", + "MobileNetV2_x1_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x1_5_pretrained.pdparams", + "MobileNetV2_x2_0": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x2_0_pretrained.pdparams" +} + +__all__ = ["MobileNetV2"] + + +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + name=None, + use_cudnn=True): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + + self._batch_norm = BatchNorm( + num_filters, + param_attr=ParamAttr(name=name + "_bn_scale"), + bias_attr=ParamAttr(name=name + "_bn_offset"), + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, inputs, if_act=True): + y = self._conv(inputs) + y = self._batch_norm(y) + if if_act: + y = F.relu6(y) + return y + + +class InvertedResidualUnit(nn.Layer): + def __init__(self, num_channels, num_in_filter, num_filters, stride, + filter_size, padding, expansion_factor, name): + super(InvertedResidualUnit, self).__init__() + num_expfilter = int(round(num_in_filter * expansion_factor)) + self._expand_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + name=name + "_expand") + + self._bottleneck_conv = ConvBNLayer( + num_channels=num_expfilter, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding=padding, + num_groups=num_expfilter, + use_cudnn=False, + name=name + "_dwise") + + self._linear_conv = ConvBNLayer( + num_channels=num_expfilter, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + name=name + "_linear") + + def forward(self, inputs, ifshortcut): + y = self._expand_conv(inputs, if_act=True) + y = self._bottleneck_conv(y, if_act=True) + y = self._linear_conv(y, if_act=False) + if ifshortcut: + y = paddle.add(inputs, y) + return y + + +class InvresiBlocks(nn.Layer): + def __init__(self, in_c, t, c, n, s, name): + super(InvresiBlocks, self).__init__() + + self._first_block = InvertedResidualUnit( + num_channels=in_c, + num_in_filter=in_c, + num_filters=c, + stride=s, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + "_1") + + self._block_list = [] + for i in range(1, n): + block = self.add_sublayer( + name + "_" + str(i + 1), + sublayer=InvertedResidualUnit( + num_channels=c, + num_in_filter=c, + num_filters=c, + stride=1, + filter_size=3, + 
padding=1, + expansion_factor=t, + name=name + "_" + str(i + 1))) + self._block_list.append(block) + + def forward(self, inputs): + y = self._first_block(inputs, ifshortcut=False) + for block in self._block_list: + y = block(y, ifshortcut=True) + return y + + +@manager.BACKBONES.add_component +class MobileNet(nn.Layer): + def __init__(self, + input_channels=3, + scale=1.0, + pretrained=None, + prefix_name=""): + super(MobileNet, self).__init__() + self.scale = scale + + bottleneck_params_list = [ + (1, 16, 1, 1), + (6, 24, 2, 2), + (6, 32, 3, 2), + (6, 64, 4, 2), + (6, 96, 3, 1), + (6, 160, 3, 2), + (6, 320, 1, 1), + ] + + self.conv1 = ConvBNLayer( + num_channels=input_channels, + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1, + name=prefix_name + "conv1_1") + + self.block_list = [] + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s = layer_setting + i += 1 + block = self.add_sublayer( + prefix_name + "conv" + str(i), + sublayer=InvresiBlocks( + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s, + name=prefix_name + "conv" + str(i))) + self.block_list.append(block) + in_c = int(c * scale) + + self.out_c = int(1280 * scale) if scale > 1.0 else 1280 + self.conv9 = ConvBNLayer( + num_channels=in_c, + num_filters=self.out_c, + filter_size=1, + stride=1, + padding=0, + name=prefix_name + "conv9") + + self.feat_channels = [int(i * scale) for i in [16, 24, 32, 96, 1280]] + self.pretrained = pretrained + self.init_weight() + + def forward(self, inputs): + feat_list = [] + y = self.conv1(inputs, if_act=True) + + block_index = 0 + for block in self.block_list: + y = block(y) + if block_index in [0, 1, 2, 4]: + feat_list.append(y) + block_index += 1 + y = self.conv9(y, if_act=True) + feat_list.append(y) + return feat_list + + def init_weight(self): + utils.load_pretrained_model(self, self.pretrained) + + +@manager.BACKBONES.add_component +def MobileNetV2(**kwargs): + model = MobileNet(scale=1.0, **kwargs) + return model diff --git a/ppmatting/models/backbone/resnet_vd.py b/ppmatting/models/backbone/resnet_vd.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdd9a57664ad80ee59846060cd7f768f757feae --- /dev/null +++ b/ppmatting/models/backbone/resnet_vd.py @@ -0,0 +1,368 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
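+# Illustrative usage of the MobileNetV2 backbone above (a sketch, not part of
+# the original files): features are tapped after blocks 0, 1, 2 and 4 plus the
+# final 1x1 conv, matching `feat_channels`.
+def _demo_mobilenet_features():
+    import paddle
+    from ppmatting.models.backbone import MobileNetV2
+    net = MobileNetV2(pretrained=None)
+    feats = net(paddle.rand([1, 3, 512, 512]))
+    return [f.shape for f in feats]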
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+from paddleseg.models import layers
+from paddleseg.utils import utils
+
+__all__ = [
+    "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
+]
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            dilation=1,
+            groups=1,
+            is_vd_mode=False,
+            act=None, ):
+        super(ConvBNLayer, self).__init__()
+
+        self.is_vd_mode = is_vd_mode
+        self._pool2d_avg = nn.AvgPool2D(
+            kernel_size=2, stride=2, padding=0, ceil_mode=True)
+        self._conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
+            dilation=dilation,
+            groups=groups,
+            bias_attr=False)
+
+        self._batch_norm = layers.SyncBatchNorm(out_channels)
+        self._act_op = layers.Activation(act=act)
+
+    def forward(self, inputs):
+        if self.is_vd_mode:
+            inputs = self._pool2d_avg(inputs)
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        y = self._act_op(y)
+
+        return y
+
+
+class BottleneckBlock(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 dilation=1):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            act='relu')
+
+        self.dilation = dilation
+
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu',
+            dilation=dilation)
+        self.conv2 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels * 4,
+            kernel_size=1,
+            act=None)
+
+        if not shortcut:
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels * 4,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first or stride == 1 else True)
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+
+        ####################################################################
+        # If the dilation rate is greater than 1, apply the corresponding
+        # padding; performance drops without the following padding.
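+        # For a 3x3 conv with dilation d, symmetric padding of d on each
+        # side preserves the spatial size; ConvBNLayer above sets padding=0
+        # whenever dilation != 1, so the explicit F.pad below supplies it.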
+        if self.dilation > 1:
+            padding = self.dilation
+            y = F.pad(y, [padding, padding, padding, padding])
+        #####################################################################
+
+        conv1 = self.conv1(y)
+        conv2 = self.conv2(conv1)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+
+        y = paddle.add(x=short, y=conv2)
+        y = F.relu(y)
+        return y
+
+
+class BasicBlock(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 shortcut=True,
+                 if_first=False):
+        super(BasicBlock, self).__init__()
+        self.stride = stride
+        self.conv0 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=stride,
+            act='relu')
+        self.conv1 = ConvBNLayer(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act=None)
+
+        if not shortcut:
+            self.short = ConvBNLayer(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                is_vd_mode=False if if_first or stride == 1 else True)
+
+        self.shortcut = shortcut
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+        y = paddle.add(x=short, y=conv1)
+        y = F.relu(y)
+
+        return y
+
+
+class ResNet_vd(nn.Layer):
+    """
+    The ResNet_vd implementation based on PaddlePaddle.
+
+    The original article refers to
+    Tong He, et al. "Bag of Tricks for Image Classification with Convolutional Neural Networks"
+    (https://arxiv.org/pdf/1812.01187.pdf).
+
+    Args:
+        layers (int, optional): The layers of ResNet_vd. The supported layers are (18, 34, 50, 101, 152, 200). Default: 50.
+        output_stride (int, optional): The stride of output features compared to input images. It should be 8, 16 or 32. Default: 32.
+        multi_grid (tuple|list, optional): The grid of stage4. Default: (1, 1, 1).
+        pretrained (str, optional): The path of pretrained model.
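+
+    Examples (an illustrative sketch, not part of the original docstring):
+
+        import paddle
+        from ppmatting.models.backbone import ResNet50_vd
+
+        net = ResNet50_vd()  # output_stride=32 by default
+        feats = net(paddle.rand([1, 3, 512, 512]))
+        # five feature maps: the stem plus the four stages, whose channels
+        # match net.feat_channels = [64, 256, 512, 1024, 2048]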
+ + """ + + def __init__(self, + input_channels=3, + layers=50, + output_stride=32, + multi_grid=(1, 1, 1), + pretrained=None): + super(ResNet_vd, self).__init__() + + self.conv1_logit = None # for gscnn shape stream + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512] + + # for channels of four returned stages + self.feat_channels = [c * 4 for c in num_filters + ] if layers >= 50 else num_filters + self.feat_channels = [64] + self.feat_channels + + dilation_dict = None + if output_stride == 8: + dilation_dict = {2: 2, 3: 4} + elif output_stride == 16: + dilation_dict = {3: 2} + + self.conv1_1 = ConvBNLayer( + in_channels=input_channels, + out_channels=32, + kernel_size=3, + stride=2, + act='relu') + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + act='relu') + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + act='relu') + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + # self.block_list = [] + self.stage_list = [] + if layers >= 50: + for block in range(len(depth)): + shortcut = False + block_list = [] + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + + ############################################################################### + # Add dilation rate for some segmentation tasks, if dilation_dict is not None. 
+                    dilation_rate = dilation_dict[
+                        block] if dilation_dict and block in dilation_dict else 1
+
+                    # `block` here is actually the stage, and `i` indexes the blocks within the stage.
+                    # At stage 4, expand the dilation_rate according to multi_grid if given.
+                    if block == 3:
+                        dilation_rate = dilation_rate * multi_grid[i]
+                    ###############################################################################
+
+                    bottleneck_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BottleneckBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block] * 4,
+                            out_channels=num_filters[block],
+                            stride=2 if i == 0 and block != 0 and
+                            dilation_rate == 1 else 1,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            dilation=dilation_rate))
+
+                    block_list.append(bottleneck_block)
+                    shortcut = True
+                self.stage_list.append(block_list)
+        else:
+            for block in range(len(depth)):
+                shortcut = False
+                block_list = []
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    basic_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BasicBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block],
+                            out_channels=num_filters[block],
+                            stride=2 if i == 0 and block != 0 else 1,
+                            shortcut=shortcut,
+                            if_first=block == i == 0))
+                    block_list.append(basic_block)
+                    shortcut = True
+                self.stage_list.append(block_list)
+
+        self.pretrained = pretrained
+        self.init_weight()
+
+    def forward(self, inputs):
+        feat_list = []
+        y = self.conv1_1(inputs)
+        y = self.conv1_2(y)
+        y = self.conv1_3(y)
+        feat_list.append(y)
+
+        y = self.pool2d_max(y)
+
+        # A feature list saves the output feature map of each stage.
+        for stage in self.stage_list:
+            for block in stage:
+                y = block(y)
+            feat_list.append(y)
+
+        return feat_list
+
+    def init_weight(self):
+        utils.load_pretrained_model(self, self.pretrained)
+
+
+@manager.BACKBONES.add_component
+def ResNet18_vd(**args):
+    model = ResNet_vd(layers=18, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def ResNet34_vd(**args):
+    model = ResNet_vd(layers=34, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def ResNet50_vd(**args):
+    model = ResNet_vd(layers=50, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def ResNet101_vd(**args):
+    model = ResNet_vd(layers=101, **args)
+    return model
+
+
+def ResNet152_vd(**args):
+    model = ResNet_vd(layers=152, **args)
+    return model
+
+
+def ResNet200_vd(**args):
+    model = ResNet_vd(layers=200, **args)
+    return model
diff --git a/ppmatting/models/backbone/vgg.py b/ppmatting/models/backbone/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b529bf0c3e25cb82ea4b4c31bec9ef30d2da59
--- /dev/null
+++ b/ppmatting/models/backbone/vgg.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
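+# Illustrative sketch (not part of the original files): with output_stride=8,
+# ResNet_vd above keeps stages 3 and 4 at 1/8 resolution via dilations {2, 4}
+# instead of striding, which dense prediction decoders rely on.
+def _demo_resnet_output_stride():
+    import paddle
+    from ppmatting.models.backbone import ResNet50_vd
+    net = ResNet50_vd(output_stride=8)
+    feats = net(paddle.rand([1, 3, 256, 256]))
+    # the last three feature maps all stay at 1/8 resolution (32x32 here)
+    return [f.shape for f in feats]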
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+
+from paddleseg.cvlibs import manager
+from paddleseg.utils import utils
+
+
+class ConvBlock(nn.Layer):
+    def __init__(self, input_channels, output_channels, groups, name=None):
+        super(ConvBlock, self).__init__()
+
+        self.groups = groups
+        self._conv_1 = Conv2D(
+            in_channels=input_channels,
+            out_channels=output_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(name=name + "1_weights"),
+            bias_attr=False)
+        if groups == 2 or groups == 3 or groups == 4:
+            self._conv_2 = Conv2D(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "2_weights"),
+                bias_attr=False)
+        if groups == 3 or groups == 4:
+            self._conv_3 = Conv2D(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "3_weights"),
+                bias_attr=False)
+        if groups == 4:
+            self._conv_4 = Conv2D(
+                in_channels=output_channels,
+                out_channels=output_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                weight_attr=ParamAttr(name=name + "4_weights"),
+                bias_attr=False)
+
+        self._pool = MaxPool2D(
+            kernel_size=2, stride=2, padding=0, return_mask=True)
+
+    def forward(self, inputs):
+        x = self._conv_1(inputs)
+        x = F.relu(x)
+        if self.groups == 2 or self.groups == 3 or self.groups == 4:
+            x = self._conv_2(x)
+            x = F.relu(x)
+        if self.groups == 3 or self.groups == 4:
+            x = self._conv_3(x)
+            x = F.relu(x)
+        if self.groups == 4:
+            x = self._conv_4(x)
+            x = F.relu(x)
+        skip = x
+        x, max_indices = self._pool(x)
+        return x, max_indices, skip
+
+
+class VGGNet(nn.Layer):
+    def __init__(self, input_channels=3, layers=11, pretrained=None):
+        super(VGGNet, self).__init__()
+        self.pretrained = pretrained
+
+        self.layers = layers
+        self.vgg_configure = {
+            11: [1, 1, 2, 2, 2],
+            13: [2, 2, 2, 2, 2],
+            16: [2, 2, 3, 3, 3],
+            19: [2, 2, 4, 4, 4]
+        }
+        assert self.layers in self.vgg_configure.keys(), \
+            "supported layers are {} but input layer is {}".format(
+                self.vgg_configure.keys(), layers)
+        self.groups = self.vgg_configure[self.layers]
+
+        # For matting, the first conv layer takes a 4-channel input; its
+        # weights are simply zero-initialized.
+        self._conv_block_1 = ConvBlock(
+            input_channels, 64, self.groups[0], name="conv1_")
+        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
+        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
+        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
+        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
+
+        # This layer should be initialized from the converted VGG fc6 weights;
+        # its initialization can be ignored for now.
+        self._conv_6 = Conv2D(
+            512, 512, kernel_size=3, padding=1, bias_attr=False)
+
+        self.init_weight()
+
+    def forward(self, inputs):
+        fea_list = []
+        ids_list = []
+        x, ids, skip = self._conv_block_1(inputs)
+        fea_list.append(skip)
+        ids_list.append(ids)
+        x, ids, skip = self._conv_block_2(x)
+        fea_list.append(skip)
+        ids_list.append(ids)
+        x, ids, skip = self._conv_block_3(x)
+        fea_list.append(skip)
+        ids_list.append(ids)
+        x, ids, skip = self._conv_block_4(x)
+        fea_list.append(skip)
+        ids_list.append(ids)
+        x, ids, skip = self._conv_block_5(x)
+        fea_list.append(skip)
+        ids_list.append(ids)
+        x = F.relu(self._conv_6(x))
+        fea_list.append(x)
+        return fea_list
+
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_pretrained_model(self, self.pretrained)
+
+
+@manager.BACKBONES.add_component
+def VGG11(**args):
+    model = VGGNet(layers=11, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def VGG13(**args):
+    model = VGGNet(layers=13, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def VGG16(**args):
+    model = VGGNet(layers=16, **args)
+    return model
+
+
+@manager.BACKBONES.add_component
+def VGG19(**args):
+    model = VGGNet(layers=19, **args)
+    return model
diff --git a/ppmatting/models/dim.py b/ppmatting/models/dim.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9ae654322242f785407e61ff7b8405d6b443b4
--- /dev/null
+++ b/ppmatting/models/dim.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddleseg.models import layers
+from paddleseg import utils
+from paddleseg.cvlibs import manager
+
+from ppmatting.models.losses import MRSD
+
+
+@manager.MODELS.add_component
+class DIM(nn.Layer):
+    """
+    The DIM implementation based on PaddlePaddle.
+
+    The original article refers to
+    Ning Xu, et al. "Deep Image Matting"
+    (https://arxiv.org/pdf/1703.03872.pdf).
+
+    Args:
+        backbone: backbone model.
+        stage (int, optional): The training stage of the model. Default: 3.
+        decoder_input_channels (int, optional): The channels of the decoder input. Default: 512.
+        pretrained (str, optional): The path of the pretrained model. Default: None.
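+
+        Example (illustrative only; assumes the 4-channel VGG16 backbone
+        defined in ppmatting.models.backbone.vgg):
+
+            import paddle
+            from ppmatting.models.backbone.vgg import VGG16
+
+            backbone = VGG16(input_channels=4)  # img (3) + trimap (1)
+            model = DIM(backbone=backbone, stage=3)
+            model.eval()
+            data = {
+                'img': paddle.rand([1, 3, 320, 320]),
+                'trimap': paddle.full([1, 1, 320, 320], 128.0)
+            }
+            alpha = model(data)  # (1, 1, 320, 320), values in [0, 1]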
+ + """ + + def __init__(self, + backbone, + stage=3, + decoder_input_channels=512, + pretrained=None): + super().__init__() + self.backbone = backbone + self.pretrained = pretrained + self.stage = stage + self.loss_func_dict = None + + decoder_output_channels = [64, 128, 256, 512] + self.decoder = Decoder( + input_channels=decoder_input_channels, + output_channels=decoder_output_channels) + if self.stage == 2: + for param in self.backbone.parameters(): + param.stop_gradient = True + for param in self.decoder.parameters(): + param.stop_gradient = True + if self.stage >= 2: + self.refine = Refine() + self.init_weight() + + def forward(self, inputs): + input_shape = paddle.shape(inputs['img'])[-2:] + x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1) + fea_list = self.backbone(x) + + # decoder stage + up_shape = [] + for i in range(5): + up_shape.append(paddle.shape(fea_list[i])[-2:]) + alpha_raw = self.decoder(fea_list, up_shape) + alpha_raw = F.interpolate( + alpha_raw, input_shape, mode='bilinear', align_corners=False) + logit_dict = {'alpha_raw': alpha_raw} + if self.stage < 2: + return logit_dict + + if self.stage >= 2: + # refine stage + refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1) + alpha_refine = self.refine(refine_input) + + # finally alpha + alpha_pred = alpha_refine + alpha_raw + alpha_pred = F.interpolate( + alpha_pred, input_shape, mode='bilinear', align_corners=False) + if not self.training: + alpha_pred = paddle.clip(alpha_pred, min=0, max=1) + logit_dict['alpha_pred'] = alpha_pred + if self.training: + loss_dict = self.loss(logit_dict, inputs) + return logit_dict, loss_dict + else: + return alpha_pred + + def loss(self, logit_dict, label_dict, loss_func_dict=None): + if loss_func_dict is None: + if self.loss_func_dict is None: + self.loss_func_dict = defaultdict(list) + self.loss_func_dict['alpha_raw'].append(MRSD()) + self.loss_func_dict['comp'].append(MRSD()) + self.loss_func_dict['alpha_pred'].append(MRSD()) + else: + self.loss_func_dict = loss_func_dict + + loss = {} + mask = label_dict['trimap'] == 128 + loss['all'] = 0 + + if self.stage != 2: + loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0]( + logit_dict['alpha_raw'], label_dict['alpha'], mask) + loss['alpha_raw'] = 0.5 * loss['alpha_raw'] + loss['all'] = loss['all'] + loss['alpha_raw'] + + if self.stage == 1 or self.stage == 3: + comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \ + (1 - logit_dict['alpha_raw']) * label_dict['bg'] + loss['comp'] = self.loss_func_dict['comp'][0]( + comp_pred, label_dict['img'], mask) + loss['comp'] = 0.5 * loss['comp'] + loss['all'] = loss['all'] + loss['comp'] + + if self.stage == 2 or self.stage == 3: + loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0]( + logit_dict['alpha_pred'], label_dict['alpha'], mask) + loss['all'] = loss['all'] + loss['alpha_pred'] + + return loss + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + +# bilinear interpolate skip connect +class Up(nn.Layer): + def __init__(self, input_channels, output_channels): + super().__init__() + self.conv = layers.ConvBNReLU( + input_channels, + output_channels, + kernel_size=5, + padding=2, + bias_attr=False) + + def forward(self, x, skip, output_shape): + x = F.interpolate( + x, size=output_shape, mode='bilinear', align_corners=False) + x = x + skip + x = self.conv(x) + x = F.relu(x) + + return x + + +class Decoder(nn.Layer): + def __init__(self, input_channels, output_channels=(64, 128, 256, 512)): + 
super().__init__() + self.deconv6 = nn.Conv2D( + input_channels, input_channels, kernel_size=1, bias_attr=False) + self.deconv5 = Up(input_channels, output_channels[-1]) + self.deconv4 = Up(output_channels[-1], output_channels[-2]) + self.deconv3 = Up(output_channels[-2], output_channels[-3]) + self.deconv2 = Up(output_channels[-3], output_channels[-4]) + self.deconv1 = Up(output_channels[-4], 64) + + self.alpha_conv = nn.Conv2D( + 64, 1, kernel_size=5, padding=2, bias_attr=False) + + def forward(self, fea_list, shape_list): + x = fea_list[-1] + x = self.deconv6(x) + x = self.deconv5(x, fea_list[4], shape_list[4]) + x = self.deconv4(x, fea_list[3], shape_list[3]) + x = self.deconv3(x, fea_list[2], shape_list[2]) + x = self.deconv2(x, fea_list[1], shape_list[1]) + x = self.deconv1(x, fea_list[0], shape_list[0]) + alpha = self.alpha_conv(x) + alpha = F.sigmoid(alpha) + + return alpha + + +class Refine(nn.Layer): + def __init__(self): + super().__init__() + self.conv1 = layers.ConvBNReLU( + 4, 64, kernel_size=3, padding=1, bias_attr=False) + self.conv2 = layers.ConvBNReLU( + 64, 64, kernel_size=3, padding=1, bias_attr=False) + self.conv3 = layers.ConvBNReLU( + 64, 64, kernel_size=3, padding=1, bias_attr=False) + self.alpha_pred = layers.ConvBNReLU( + 64, 1, kernel_size=3, padding=1, bias_attr=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + alpha = self.alpha_pred(x) + + return alpha diff --git a/ppmatting/models/gca.py b/ppmatting/models/gca.py new file mode 100644 index 0000000000000000000000000000000000000000..369a913570682f85ea696beaf3b78b7c2ec88141 --- /dev/null +++ b/ppmatting/models/gca.py @@ -0,0 +1,305 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting +# and https://github.com/open-mmlab/mmediting + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddleseg.models import layers +from paddleseg import utils +from paddleseg.cvlibs import manager, param_init + +from ppmatting.models.layers import GuidedCxtAtten + + +@manager.MODELS.add_component +class GCABaseline(nn.Layer): + def __init__(self, backbone, pretrained=None): + super().__init__() + self.encoder = backbone + self.decoder = ResShortCut_D_Dec([2, 3, 3, 2]) + + def forward(self, inputs): + + x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1) + embedding, mid_fea = self.encoder(x) + alpha_pred = self.decoder(embedding, mid_fea) + + if self.training: + logit_dict = {'alpha_pred': alpha_pred, } + loss_dict = {} + alpha_gt = inputs['alpha'] + loss_dict["alpha"] = F.l1_loss(alpha_pred, alpha_gt) + loss_dict["all"] = loss_dict["alpha"] + return logit_dict, loss_dict + + return alpha_pred + + +@manager.MODELS.add_component +class GCA(GCABaseline): + def __init__(self, backbone, pretrained=None): + super().__init__(backbone, pretrained) + self.decoder = ResGuidedCxtAtten_Dec([2, 3, 3, 2]) + + +def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1): + """5x5 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=5, + stride=stride, + padding=2, + groups=groups, + bias_attr=False, + dilation=dilation) + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias_attr=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2D( + in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False) + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + upsample=None, + norm_layer=None, + large_kernel=False): + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm + self.stride = stride + conv = conv5x5 if large_kernel else conv3x3 + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + if self.stride > 1: + self.conv1 = nn.utils.spectral_norm( + nn.Conv2DTranspose( + inplanes, + inplanes, + kernel_size=4, + stride=2, + padding=1, + bias_attr=False)) + else: + self.conv1 = nn.utils.spectral_norm(conv(inplanes, inplanes)) + self.bn1 = norm_layer(inplanes) + self.activation = nn.LeakyReLU(0.2) + self.conv2 = nn.utils.spectral_norm(conv(inplanes, planes)) + self.bn2 = norm_layer(planes) + self.upsample = upsample + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.activation(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.upsample is not None: + identity = self.upsample(x) + + out += identity + out = self.activation(out) + + return out + + +class ResNet_D_Dec(nn.Layer): + def __init__(self, + layers=[3, 4, 4, 2], + norm_layer=None, + large_kernel=False, + late_downsample=False): + super().__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm + self._norm_layer = norm_layer + self.large_kernel = large_kernel + self.kernel_size = 5 if self.large_kernel else 3 + + self.inplanes = 512 if layers[0] > 0 else 256 + self.late_downsample = late_downsample + self.midplanes = 64 if late_downsample else 32 + + self.conv1 = 
nn.utils.spectral_norm(
+            nn.Conv2DTranspose(
+                self.midplanes,
+                32,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias_attr=False))
+        self.bn1 = norm_layer(32)
+        self.leaky_relu = nn.LeakyReLU(0.2)
+        self.conv2 = nn.Conv2D(
+            32,
+            1,
+            kernel_size=self.kernel_size,
+            stride=1,
+            padding=self.kernel_size // 2)
+        self.upsample = nn.UpsamplingNearest2D(scale_factor=2)
+        self.tanh = nn.Tanh()
+        self.layer1 = self._make_layer(BasicBlock, 256, layers[0], stride=2)
+        self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(BasicBlock, 64, layers[2], stride=2)
+        self.layer4 = self._make_layer(
+            BasicBlock, self.midplanes, layers[3], stride=2)
+
+        self.init_weight()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        if blocks == 0:
+            return nn.Sequential(nn.Identity())
+        norm_layer = self._norm_layer
+        upsample = None
+        if stride != 1:
+            upsample = nn.Sequential(
+                nn.UpsamplingNearest2D(scale_factor=2),
+                nn.utils.spectral_norm(
+                    conv1x1(self.inplanes, planes * block.expansion)),
+                norm_layer(planes * block.expansion), )
+        elif self.inplanes != planes * block.expansion:
+            upsample = nn.Sequential(
+                nn.utils.spectral_norm(
+                    conv1x1(self.inplanes, planes * block.expansion)),
+                norm_layer(planes * block.expansion), )
+
+        layers = [
+            block(self.inplanes, planes, stride, upsample, norm_layer,
+                  self.large_kernel)
+        ]
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    norm_layer=norm_layer,
+                    large_kernel=self.large_kernel))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, mid_fea):
+        x = self.layer1(x)  # N x 256 x 32 x 32
+        x = self.layer2(x)  # N x 128 x 64 x 64
+        x = self.layer3(x)  # N x 64 x 128 x 128
+        x = self.layer4(x)  # N x 32 x 256 x 256
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.leaky_relu(x)
+        x = self.conv2(x)
+
+        alpha = (self.tanh(x) + 1.0) / 2.0
+
+        return alpha
+
+    def init_weight(self):
+        for layer in self.sublayers():
+            if isinstance(layer, nn.Conv2D):
+                if hasattr(layer, "weight_orig"):
+                    param = layer.weight_orig
+                else:
+                    param = layer.weight
+                param_init.xavier_uniform(param)
+
+            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
+                param_init.constant_init(layer.weight, value=1.0)
+                param_init.constant_init(layer.bias, value=0.0)
+
+            elif isinstance(layer, BasicBlock):
+                param_init.constant_init(layer.bn2.weight, value=0.0)
+
+
+class ResShortCut_D_Dec(ResNet_D_Dec):
+    def __init__(self,
+                 layers,
+                 norm_layer=None,
+                 large_kernel=False,
+                 late_downsample=False):
+        super().__init__(
+            layers, norm_layer, large_kernel, late_downsample=late_downsample)
+
+    def forward(self, x, mid_fea):
+        fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut']
+        x = self.layer1(x) + fea5
+        x = self.layer2(x) + fea4
+        x = self.layer3(x) + fea3
+        x = self.layer4(x) + fea2
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.leaky_relu(x) + fea1
+        x = self.conv2(x)
+
+        alpha = (self.tanh(x) + 1.0) / 2.0
+
+        return alpha
+
+
+class ResGuidedCxtAtten_Dec(ResNet_D_Dec):
+    def __init__(self,
+                 layers,
+                 norm_layer=None,
+                 large_kernel=False,
+                 late_downsample=False):
+        super().__init__(
+            layers, norm_layer, large_kernel, late_downsample=late_downsample)
+        self.gca = GuidedCxtAtten(128, 128)
+
+    def forward(self, x, mid_fea):
+        fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut']
+        im = mid_fea['image_fea']
+        x = self.layer1(x) + fea5  # N x 256 x 32 x 32
+        x = self.layer2(x) + fea4  # N x 128 x 64 x 64
+        x = self.gca(im, x, mid_fea['unknown'])  # contextual attention
+        x = self.layer3(x) + fea3  # N x 64 x 128 x 128
+        x = self.layer4(x) + fea2  # N x 32 x 256 x 256
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.leaky_relu(x) + fea1
+        x = self.conv2(x)
+
+        alpha = (self.tanh(x) + 1.0) / 2.0
+
+        return alpha
diff --git a/ppmatting/models/human_matting.py b/ppmatting/models/human_matting.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf315edfa563fe231a119dd15b749c41157c988c
--- /dev/null
+++ b/ppmatting/models/human_matting.py
@@ -0,0 +1,454 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import time
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddleseg
+from paddleseg.models import layers
+from paddleseg import utils
+from paddleseg.cvlibs import manager
+
+from ppmatting.models.losses import MRSD
+
+
+def conv_up_psp(in_channels, out_channels, up_sample):
+    return nn.Sequential(
+        layers.ConvBNReLU(
+            in_channels, out_channels, 3, padding=1),
+        nn.Upsample(
+            scale_factor=up_sample, mode='bilinear', align_corners=False))
+
+
+@manager.MODELS.add_component
+class HumanMatting(nn.Layer):
+    """A trimap-free human matting model with a glance (semantic) branch,
+    a focus (detail) branch and an optional refiner."""
+
+    def __init__(self,
+                 backbone,
+                 pretrained=None,
+                 backbone_scale=0.25,
+                 refine_kernel_size=3,
+                 if_refine=True):
+        super().__init__()
+        if if_refine:
+            if backbone_scale > 0.5:
+                raise ValueError(
+                    'backbone_scale should not be greater than 1/2, but it is {}'
+                    .format(backbone_scale))
+        else:
+            backbone_scale = 1
+
+        self.backbone = backbone
+        self.backbone_scale = backbone_scale
+        self.pretrained = pretrained
+        self.if_refine = if_refine
+        if if_refine:
+            self.refiner = Refiner(kernel_size=refine_kernel_size)
+        self.loss_func_dict = None
+
+        self.backbone_channels = backbone.feat_channels
+        ######################
+        ### Decoder part - Glance
+        ######################
+        self.psp_module = layers.PPModule(
+            self.backbone_channels[-1],
+            512,
+            bin_sizes=(1, 3, 5),
+            dim_reduction=False,
+            align_corners=False)
+        self.psp4 = conv_up_psp(512, 256, 2)
+        self.psp3 = conv_up_psp(512, 128, 4)
+        self.psp2 = conv_up_psp(512, 64, 8)
+        self.psp1 = conv_up_psp(512, 64, 16)
+        # stage 5g
+        self.decoder5_g = nn.Sequential(
+            layers.ConvBNReLU(
+                512 + self.backbone_channels[-1], 512, 3, padding=1),
+            layers.ConvBNReLU(
+                512, 512, 3, padding=2, dilation=2),
+            layers.ConvBNReLU(
+                512, 256, 3, padding=2, dilation=2),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 4g
+        self.decoder4_g = nn.Sequential(
+            layers.ConvBNReLU(
+                512, 256, 3, padding=1),
+            layers.ConvBNReLU(
+                256, 256, 3, padding=1),
+            layers.ConvBNReLU(
+                256, 128, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 3g
+        self.decoder3_g = nn.Sequential(
+            layers.ConvBNReLU(
+                256, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 2g
+        self.decoder2_g = nn.Sequential(
+            layers.ConvBNReLU(
+                128, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 1g
+        self.decoder1_g = nn.Sequential(
+            layers.ConvBNReLU(
+                128, 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 0g
+        self.decoder0_g = nn.Sequential(
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            nn.Conv2D(
+                64, 3, 3, padding=1))
+
+        ##########################
+        ### Decoder part - FOCUS
+        ##########################
+        self.bridge_block = nn.Sequential(
+            layers.ConvBNReLU(
+                self.backbone_channels[-1], 512, 3, dilation=2, padding=2),
+            layers.ConvBNReLU(
+                512, 512, 3, dilation=2, padding=2),
+            layers.ConvBNReLU(
+                512, 512, 3, dilation=2, padding=2))
+        # stage 5f
+        self.decoder5_f = nn.Sequential(
+            layers.ConvBNReLU(
+                512 + self.backbone_channels[-1], 512, 3, padding=1),
+            layers.ConvBNReLU(
+                512, 512, 3, padding=2, dilation=2),
+            layers.ConvBNReLU(
+                512, 256, 3, padding=2, dilation=2),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 4f
+        self.decoder4_f = nn.Sequential(
+            layers.ConvBNReLU(
+                256 + self.backbone_channels[-2], 256, 3, padding=1),
+            layers.ConvBNReLU(
+                256, 256, 3, padding=1),
+            layers.ConvBNReLU(
+                256, 128, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 3f
+        self.decoder3_f = nn.Sequential(
+            layers.ConvBNReLU(
+                128 + self.backbone_channels[-3], 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 2f
+        self.decoder2_f = nn.Sequential(
+            layers.ConvBNReLU(
+                64 + self.backbone_channels[-4], 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 128, 3, padding=1),
+            layers.ConvBNReLU(
+                128, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 1f
+        self.decoder1_f = nn.Sequential(
+            layers.ConvBNReLU(
+                64 + self.backbone_channels[-5], 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            nn.Upsample(
+                scale_factor=2, mode='bilinear', align_corners=False))
+        # stage 0f
+        self.decoder0_f = nn.Sequential(
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            layers.ConvBNReLU(
+                64, 64, 3, padding=1),
+            nn.Conv2D(
+                64, 1 + 1 + 32, 3, padding=1))
+        self.init_weight()
+
+    def forward(self, data):
+        src = data['img']
+        src_h, src_w = paddle.shape(src)[2:]
+        if self.if_refine:
+            # This check is not needed when exporting.
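+            # The refiner works at 1/2 and then full resolution while the
+            # backbone runs on a downsampled copy, so both sides of the input
+            # must be multiples of 4 for the shapes to line up.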
+            if isinstance(src_h, paddle.Tensor):
+                if (src_h % 4 != 0) or (src_w % 4 != 0):
+                    raise ValueError(
+                        'The input image must have width and height that are divisible by 4'
+                    )
+
+        # Downsample src for backbone
+        src_sm = F.interpolate(
+            src,
+            scale_factor=self.backbone_scale,
+            mode='bilinear',
+            align_corners=False)
+
+        # Base
+        fea_list = self.backbone(src_sm)
+        ##########################
+        ### Decoder part - GLANCE
+        ##########################
+        # psp: N, 512, H/32, W/32
+        psp = self.psp_module(fea_list[-1])
+        # d5_g: N, 256, H/16, W/16
+        d5_g = self.decoder5_g(paddle.concat((psp, fea_list[-1]), 1))
+        # d4_g: N, 128, H/8, W/8
+        d4_g = self.decoder4_g(paddle.concat((self.psp4(psp), d5_g), 1))
+        # d3_g: N, 64, H/4, W/4
+        d3_g = self.decoder3_g(paddle.concat((self.psp3(psp), d4_g), 1))
+        # d2_g: N, 64, H/2, W/2
+        d2_g = self.decoder2_g(paddle.concat((self.psp2(psp), d3_g), 1))
+        # d1_g: N, 64, H, W
+        d1_g = self.decoder1_g(paddle.concat((self.psp1(psp), d2_g), 1))
+        # d0_g: N, 3, H, W
+        d0_g = self.decoder0_g(d1_g)
+        # The 1st channel is foreground. The 2nd is transition region. The 3rd is background.
+        glance_sigmoid = F.softmax(d0_g, axis=1)
+
+        ##########################
+        ### Decoder part - FOCUS
+        ##########################
+        # bb: N, 512, H/32, W/32
+        bb = self.bridge_block(fea_list[-1])
+        # d5_f: N, 256, H/16, W/16
+        d5_f = self.decoder5_f(paddle.concat((bb, fea_list[-1]), 1))
+        # d4_f: N, 128, H/8, W/8
+        d4_f = self.decoder4_f(paddle.concat((d5_f, fea_list[-2]), 1))
+        # d3_f: N, 64, H/4, W/4
+        d3_f = self.decoder3_f(paddle.concat((d4_f, fea_list[-3]), 1))
+        # d2_f: N, 64, H/2, W/2
+        d2_f = self.decoder2_f(paddle.concat((d3_f, fea_list[-4]), 1))
+        # d1_f: N, 64, H, W
+        d1_f = self.decoder1_f(paddle.concat((d2_f, fea_list[-5]), 1))
+        # d0_f: N, 34, H, W
+        d0_f = self.decoder0_f(d1_f)
+        focus_sigmoid = F.sigmoid(d0_f[:, 0:1, :, :])
+        pha_sm = self.fusion(glance_sigmoid, focus_sigmoid)
+        err_sm = d0_f[:, 1:2, :, :]
+        err_sm = paddle.clip(err_sm, 0., 1.)
+        hid_sm = F.relu(d0_f[:, 2:, :, :])
+
+        # Refiner
+        if self.if_refine:
+            pha = self.refiner(
+                src=src, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_sigmoid)
+            # Clamp outputs
+            pha = paddle.clip(pha, 0., 1.)
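+
+        # d0_f packs three outputs: the focus alpha (1 channel), a per-pixel
+        # error estimate (1 channel) and a 32-channel hidden feature; the
+        # refiner consumes them together with the full-resolution source.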
+
+        if self.training:
+            logit_dict = {
+                'glance': glance_sigmoid,
+                'focus': focus_sigmoid,
+                'fusion': pha_sm,
+                'error': err_sm
+            }
+            if self.if_refine:
+                logit_dict['refine'] = pha
+            loss_dict = self.loss(logit_dict, data)
+            return logit_dict, loss_dict
+        else:
+            return pha if self.if_refine else pha_sm
+
+    def loss(self, logit_dict, label_dict, loss_func_dict=None):
+        if loss_func_dict is None:
+            if self.loss_func_dict is None:
+                self.loss_func_dict = defaultdict(list)
+                self.loss_func_dict['glance'].append(nn.NLLLoss())
+                self.loss_func_dict['focus'].append(MRSD())
+                self.loss_func_dict['cm'].append(MRSD())
+                self.loss_func_dict['err'].append(paddleseg.models.MSELoss())
+                self.loss_func_dict['refine'].append(paddleseg.models.L1Loss())
+        else:
+            self.loss_func_dict = loss_func_dict
+
+        loss = {}
+
+        # glance loss computation
+        # get glance label
+        glance_label = F.interpolate(
+            label_dict['trimap'],
+            logit_dict['glance'].shape[2:],
+            mode='nearest',
+            align_corners=False)
+        glance_label_trans = (glance_label == 128).astype('int64')
+        glance_label_bg = (glance_label == 0).astype('int64')
+        glance_label = glance_label_trans + glance_label_bg * 2
+        loss_glance = self.loss_func_dict['glance'][0](
+            paddle.log(logit_dict['glance'] + 1e-6), glance_label.squeeze(1))
+        loss['glance'] = loss_glance
+
+        # focus loss computation
+        focus_label = F.interpolate(
+            label_dict['alpha'],
+            logit_dict['focus'].shape[2:],
+            mode='bilinear',
+            align_corners=False)
+        loss_focus = self.loss_func_dict['focus'][0](
+            logit_dict['focus'], focus_label, glance_label_trans)
+        loss['focus'] = loss_focus
+
+        # collaborative matting loss
+        loss_cm_func = self.loss_func_dict['cm']
+        # fusion_sigmoid loss
+        loss_cm = loss_cm_func[0](logit_dict['fusion'], focus_label)
+        loss['cm'] = loss_cm
+
+        # error loss
+        err = F.interpolate(
+            logit_dict['error'],
+            label_dict['alpha'].shape[2:],
+            mode='bilinear',
+            align_corners=False)
+        err_label = (F.interpolate(
+            logit_dict['fusion'],
+            label_dict['alpha'].shape[2:],
+            mode='bilinear',
+            align_corners=False) - label_dict['alpha']).abs()
+        loss_err = self.loss_func_dict['err'][0](err, err_label)
+        loss['err'] = loss_err
+
+        loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm + loss_err
+
+        # refine loss
+        if self.if_refine:
+            loss_refine = self.loss_func_dict['refine'][0](
+                logit_dict['refine'], label_dict['alpha'])
+            loss['refine'] = loss_refine
+            loss_all = loss_all + loss_refine
+
+        loss['all'] = loss_all
+        return loss
+
+    def fusion(self, glance_sigmoid, focus_sigmoid):
+        # glance_sigmoid [N, 3, H, W].
+        # In the argmax index, 0 is foreground, 1 is transition, 2 is background.
+        # After fusion, the foreground is 1, the background is 0, and the
+        # transition region lies in (0, 1).
+        index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
+        transition_mask = (index == 1).astype('float32')
+        fg = (index == 0).astype('float32')
+        fusion_sigmoid = focus_sigmoid * transition_mask + fg
+        return fusion_sigmoid
+
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_entire_model(self, self.pretrained)
+
+
+class Refiner(nn.Layer):
+    '''
+    Refiner refines the coarse output to full resolution.
+
+    Args:
+        kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.
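+
+    The refinement runs in two steps: the coarse maps are first upsampled to
+    half resolution and refined, then upsampled to full resolution for the
+    final prediction. With kernel_size=3 the inputs are padded so that the
+    valid (padding=0) convolutions give back the original spatial size.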
+    '''
+
+    def __init__(self, kernel_size=3):
+        super().__init__()
+        if kernel_size not in [1, 3]:
+            raise ValueError("kernel_size must be in [1, 3]")
+
+        self.kernel_size = kernel_size
+
+        channels = [32, 24, 16, 12, 1]
+        self.conv1 = layers.ConvBNReLU(
+            channels[0] + 4 + 3,
+            channels[1],
+            kernel_size,
+            padding=0,
+            bias_attr=False)
+        self.conv2 = layers.ConvBNReLU(
+            channels[1], channels[2], kernel_size, padding=0, bias_attr=False)
+        self.conv3 = layers.ConvBNReLU(
+            channels[2] + 3,
+            channels[3],
+            kernel_size,
+            padding=0,
+            bias_attr=False)
+        self.conv4 = nn.Conv2D(
+            channels[3], channels[4], kernel_size, padding=0, bias_attr=True)
+
+    def forward(self, src, pha, err, hid, tri):
+        '''
+        Args:
+            src: (B, 3, H, W) full resolution source image.
+            pha: (B, 1, Hc, Wc) coarse alpha prediction.
+            err: (B, 1, Hc, Wc) coarse error prediction.
+            hid: (B, 32, Hc, Wc) coarse hidden encoding.
+            tri: (B, 3, Hc, Wc) trimap prediction (glance softmax).
+        '''
+        h_full, w_full = paddle.shape(src)[2:]
+        h_half, w_half = h_full // 2, w_full // 2
+
+        x = paddle.concat([hid, pha, tri], axis=1)
+        x = F.interpolate(
+            x,
+            paddle.concat((h_half, w_half)),
+            mode='bilinear',
+            align_corners=False)
+        y = F.interpolate(
+            src,
+            paddle.concat((h_half, w_half)),
+            mode='bilinear',
+            align_corners=False)
+
+        if self.kernel_size == 3:
+            x = F.pad(x, [3, 3, 3, 3])
+            y = F.pad(y, [3, 3, 3, 3])
+
+        x = self.conv1(paddle.concat([x, y], axis=1))
+        x = self.conv2(x)
+
+        if self.kernel_size == 3:
+            x = F.interpolate(x, paddle.concat((h_full + 4, w_full + 4)))
+            y = F.pad(src, [2, 2, 2, 2])
+        else:
+            x = F.interpolate(
+                x, paddle.concat((h_full, w_full)), mode='nearest')
+            y = src
+
+        x = self.conv3(paddle.concat([x, y], axis=1))
+        x = self.conv4(x)
+
+        pha = x
+        return pha
diff --git a/ppmatting/models/layers/__init__.py b/ppmatting/models/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31eba2cacd64eddaf0734495b5a992a86b7bad37
--- /dev/null
+++ b/ppmatting/models/layers/__init__.py
@@ -0,0 +1,15 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .gca_module import GuidedCxtAtten
diff --git a/ppmatting/models/layers/gca_module.py b/ppmatting/models/layers/gca_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba8654efc9bd24de2e127393ad8338d21964e4a5
--- /dev/null
+++ b/ppmatting/models/layers/gca_module.py
@@ -0,0 +1,211 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +# The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting +# and https://github.com/open-mmlab/mmediting + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg.cvlibs import param_init + + +class GuidedCxtAtten(nn.Layer): + def __init__(self, + out_channels, + guidance_channels, + kernel_size=3, + stride=1, + rate=2): + super().__init__() + + self.kernel_size = kernel_size + self.rate = rate + self.stride = stride + self.guidance_conv = nn.Conv2D( + in_channels=guidance_channels, + out_channels=guidance_channels // 2, + kernel_size=1) + + self.out_conv = nn.Sequential( + nn.Conv2D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + bias_attr=False), + nn.BatchNorm(out_channels)) + + self.init_weight() + + def init_weight(self): + param_init.xavier_uniform(self.guidance_conv.weight) + param_init.constant_init(self.guidance_conv.bias, value=0.0) + param_init.xavier_uniform(self.out_conv[0].weight) + param_init.constant_init(self.out_conv[1].weight, value=1e-3) + param_init.constant_init(self.out_conv[1].bias, value=0.0) + + def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.): + + img_feat = self.guidance_conv(img_feat) + img_feat = F.interpolate( + img_feat, scale_factor=1 / self.rate, mode='nearest') + + # process unknown mask + unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat, + softmax_scale) + + img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches( + img_feat, alpha_feat, unknown) + + self_mask = self.get_self_correlation_mask(img_feat) + + # split tensors by batch dimension; tuple is returned + img_groups = paddle.split(img_feat, 1, axis=0) + img_ps_groups = paddle.split(img_ps, 1, axis=0) + alpha_ps_groups = paddle.split(alpha_ps, 1, axis=0) + unknown_ps_groups = paddle.split(unknown_ps, 1, axis=0) + scale_groups = paddle.split(softmax_scale, 1, axis=0) + groups = (img_groups, img_ps_groups, alpha_ps_groups, unknown_ps_groups, + scale_groups) + + y = [] + + for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups): + # conv for compare + similarity_map = self.compute_similarity_map(img_i, img_ps_i) + + gca_score = self.compute_guided_attention_score( + similarity_map, unknown_ps_i, scale_i, self_mask) + + yi = self.propagate_alpha_feature(gca_score, alpha_ps_i) + + y.append(yi) + + y = paddle.concat(y, axis=0) # back to the mini-batch + y = paddle.reshape(y, alpha_feat.shape) + + y = self.out_conv(y) + alpha_feat + + return y + + def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown): + + # extract image feature patches with shape: + # (N, img_h*img_w, img_c, img_ks, img_ks) + img_ks = self.kernel_size + img_ps = self.extract_patches(img_feat, img_ks, self.stride) + + # extract alpha feature patches with shape: + # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks) + alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate) + + # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1) + unknown_ps = self.extract_patches(unknown, img_ks, self.stride) + unknown_ps = unknown_ps.squeeze(axis=2) # squeeze channel dimension + unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True) + + return img_ps, alpha_ps, unknown_ps + + def extract_patches(self, x, kernel_size, stride): + n, c, _, _ = x.shape + x = self.pad(x, kernel_size, stride) + x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride]) + x = 
paddle.transpose(x, (0, 2, 1)) + x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size)) + + return x + + def pad(self, x, kernel_size, stride): + left = (kernel_size - stride + 1) // 2 + right = (kernel_size - stride) // 2 + pad = (left, right, left, right) + return F.pad(x, pad, mode='reflect') + + def compute_guided_attention_score(self, similarity_map, unknown_ps, scale, + self_mask): + # scale the correlation with predicted scale factor for known and + # unknown area + unknown_scale, known_scale = scale[0] + out = similarity_map * ( + unknown_scale * paddle.greater_than(unknown_ps, + paddle.to_tensor([0.])) + + known_scale * paddle.less_equal(unknown_ps, paddle.to_tensor([0.]))) + # mask itself, self-mask only applied to unknown area + out = out + self_mask * unknown_ps + gca_score = F.softmax(out, axis=1) + + return gca_score + + def propagate_alpha_feature(self, gca_score, alpha_ps): + + alpha_ps = alpha_ps[0] # squeeze dim 0 + if self.rate == 1: + gca_score = self.pad(gca_score, kernel_size=2, stride=1) + alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3)) + out = F.conv2d(gca_score, alpha_ps) / 4. + else: + out = F.conv2d_transpose( + gca_score, alpha_ps, stride=self.rate, padding=1) / 4. + + return out + + def compute_similarity_map(self, img_feat, img_ps): + img_ps = img_ps[0] # squeeze dim 0 + # convolve the feature to get correlation (similarity) map + img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4) + img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect') + similarity_map = F.conv2d(img_feat, img_ps_normed) + + return similarity_map + + def get_self_correlation_mask(self, img_feat): + _, _, h, w = img_feat.shape + self_mask = F.one_hot( + paddle.reshape(paddle.arange(h * w), (h, w)), + num_classes=int(h * w)) + + self_mask = paddle.transpose(self_mask, (2, 0, 1)) + self_mask = paddle.reshape(self_mask, (1, h * w, h, w)) + + return self_mask * (-1e4) + + def process_unknown_mask(self, unknown, img_feat, softmax_scale): + + n, _, h, w = img_feat.shape + + if unknown is not None: + unknown = unknown.clone() + unknown = F.interpolate( + unknown, scale_factor=1 / self.rate, mode='nearest') + unknown_mean = unknown.mean(axis=[2, 3]) + known_mean = 1 - unknown_mean + unknown_scale = paddle.clip( + paddle.sqrt(unknown_mean / known_mean), 0.1, 10) + known_scale = paddle.clip( + paddle.sqrt(known_mean / unknown_mean), 0.1, 10) + softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1) + else: + unknown = paddle.ones([n, 1, h, w]) + softmax_scale = paddle.reshape( + paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2)) + softmax_scale = paddle.expand(softmax_scale, (n, 2)) + + return unknown, softmax_scale + + @staticmethod + def l2_norm(x): + x = x**2 + x = x.sum(axis=[1, 2, 3], keepdim=True) + return paddle.sqrt(x) diff --git a/ppmatting/models/losses/__init__.py b/ppmatting/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..039c22ee65cafd68a34ae6677cf7c7f8607b2c54 --- /dev/null +++ b/ppmatting/models/losses/__init__.py @@ -0,0 +1 @@ +from .loss import * diff --git a/ppmatting/models/losses/loss.py b/ppmatting/models/losses/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..750cb7b33b075c0c890e72a44ba041ad11b1bc4a --- /dev/null +++ b/ppmatting/models/losses/loss.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class MRSD(nn.Layer):
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, logit, label, mask=None):
+        """
+        Forward computation.
+
+        Args:
+            logit (Tensor): Logit tensor, the data type is float32 or float64.
+            label (Tensor): Label tensor, the data type is float32 or float64. The shape should be the same as logit.
+            mask (Tensor, optional): The mask where the loss is valid. Default: None.
+        """
+        if len(label.shape) == 3:
+            label = label.unsqueeze(1)
+        sd = paddle.square(logit - label)
+        loss = paddle.sqrt(sd + self.eps)
+        if mask is not None:
+            mask = mask.astype('float32')
+            if len(mask.shape) == 3:
+                mask = mask.unsqueeze(1)
+            loss = loss * mask
+            loss = loss.sum() / (mask.sum() + self.eps)
+            mask.stop_gradient = True
+        else:
+            loss = loss.mean()
+
+        return loss
+
+
+@manager.LOSSES.add_component
+class GradientLoss(nn.Layer):
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.kernel_x, self.kernel_y = self.sobel_kernel()
+        self.eps = eps
+
+    def forward(self, logit, label, mask=None):
+        if len(label.shape) == 3:
+            label = label.unsqueeze(1)
+        if mask is not None:
+            if len(mask.shape) == 3:
+                mask = mask.unsqueeze(1)
+            logit = logit * mask
+            label = label * mask
+            loss = paddle.sum(
+                F.l1_loss(self.sobel(logit), self.sobel(label), 'none')) / (
+                    mask.sum() + self.eps)
+        else:
+            loss = F.l1_loss(self.sobel(logit), self.sobel(label), 'mean')
+
+        return loss
+
+    def sobel(self, input):
+        """Using Sobel to compute gradient. Return the magnitude."""
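+        # Gradients are computed per channel by folding N*C into the batch
+        # dimension, filtering with the normalized Sobel kernels, and taking
+        # the L2 magnitude of (grad_x, grad_y).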
+        if not len(input.shape) == 4:
+            raise ValueError("Invalid input shape, we expect NCHW, but it is ",
+                             input.shape)
+
+        n, c, h, w = input.shape
+
+        input_pad = paddle.reshape(input, (n * c, 1, h, w))
+        input_pad = F.pad(input_pad, pad=[1, 1, 1, 1], mode='replicate')
+
+        grad_x = F.conv2d(input_pad, self.kernel_x, padding=0)
+        grad_y = F.conv2d(input_pad, self.kernel_y, padding=0)
+
+        mag = paddle.sqrt(grad_x * grad_x + grad_y * grad_y + self.eps)
+        mag = paddle.reshape(mag, (n, c, h, w))
+
+        return mag
+
+    def sobel_kernel(self):
+        kernel_x = paddle.to_tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0],
+                                     [-1.0, 0.0, 1.0]]).astype('float32')
+        kernel_x = kernel_x / kernel_x.abs().sum()
+        kernel_y = kernel_x.transpose([1, 0])
+        kernel_x = kernel_x.unsqueeze(0).unsqueeze(0)
+        kernel_y = kernel_y.unsqueeze(0).unsqueeze(0)
+        kernel_x.stop_gradient = True
+        kernel_y.stop_gradient = True
+        return kernel_x, kernel_y
+
+
+@manager.LOSSES.add_component
+class LaplacianLoss(nn.Layer):
+    """
+    Laplacian loss refers to
+    https://github.com/JizhiziLi/AIM/blob/master/core/evaluate.py#L83
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.gauss_kernel = self.build_gauss_kernel(
+            size=5, sigma=1.0, n_channels=1)
+
+    def forward(self, logit, label, mask=None):
+        if len(label.shape) == 3:
+            label = label.unsqueeze(1)
+        if mask is not None:
+            if len(mask.shape) == 3:
+                mask = mask.unsqueeze(1)
+            logit = logit * mask
+            label = label * mask
+        pyr_label = self.laplacian_pyramid(label, self.gauss_kernel, 5)
+        pyr_logit = self.laplacian_pyramid(logit, self.gauss_kernel, 5)
+        loss = sum(F.l1_loss(a, b) for a, b in zip(pyr_label, pyr_logit))
+
+        return loss
+
+    def build_gauss_kernel(self, size=5, sigma=1.0, n_channels=1):
+        if size % 2 != 1:
+            raise ValueError("kernel size must be odd")
+        grid = np.float32(np.mgrid[0:size, 0:size].T)
+        gaussian = lambda x: np.exp((x - size // 2)**2 / (-2 * sigma**2))**2
+        kernel = np.sum(gaussian(grid), axis=2)
+        kernel /= np.sum(kernel)
+        kernel = np.tile(kernel, (n_channels, 1, 1))
+        kernel = paddle.to_tensor(kernel[:, None, :, :])
+        kernel.stop_gradient = True
+        return kernel
+
+    def conv_gauss(self, input, kernel):
+        n_channels, _, kh, kw = kernel.shape
+        # Pad symmetrically: [left, right, top, bottom].
+        x = F.pad(input, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
+        x = F.conv2d(x, kernel, groups=n_channels)
+
+        return x
+
+    def laplacian_pyramid(self, input, kernel, max_levels=5):
+        current = input
+        pyr = []
+        for level in range(max_levels):
+            filtered = self.conv_gauss(current, kernel)
+            diff = current - filtered
+            pyr.append(diff)
+            current = F.avg_pool2d(filtered, 2)
+        pyr.append(current)
+        return pyr
diff --git a/ppmatting/models/modnet.py b/ppmatting/models/modnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecadfdd1a1710980e36a23bc82717e3081ad64e9
--- /dev/null
+++ b/ppmatting/models/modnet.py
@@ -0,0 +1,494 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
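+
+# MODNet decomposes matting into three sub-tasks handled by dedicated branches:
+# a low-resolution semantic branch, a high-resolution detail branch and a
+# fusion branch; the loss below supervises each of them separately.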
+
+from collections import defaultdict
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+import scipy.ndimage
+import paddleseg
+from paddleseg.models import layers, losses
+from paddleseg import utils
+from paddleseg.cvlibs import manager, param_init
+
+
+@manager.MODELS.add_component
+class MODNet(nn.Layer):
+    """
+    The MODNet implementation based on PaddlePaddle.
+
+    The original article refers to
+    Zhanghan Ke, et al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?"
+    (https://arxiv.org/pdf/2011.11961.pdf).
+
+    Args:
+        backbone: backbone model.
+        hr_channels (int, optional): The channels of the high-resolution branch. Default: 32.
+        pretrained (str, optional): The path of the pretrained model. Default: None.
+
+    """
+
+    def __init__(self, backbone, hr_channels=32, pretrained=None):
+        super().__init__()
+        self.backbone = backbone
+        self.pretrained = pretrained
+        self.head = MODNetHead(
+            hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
+        self.init_weight()
+        self.blurer = GaussianBlurLayer(1, 3)
+        self.loss_func_dict = None
+
+    def forward(self, inputs):
+        """
+        If training, return a dict.
+        If evaluation, return the final alpha prediction.
+        """
+        x = inputs['img']
+        feat_list = self.backbone(x)
+        y = self.head(inputs=inputs, feat_list=feat_list)
+        if self.training:
+            loss = self.loss(y, inputs)
+            return y, loss
+        else:
+            return y
+
+    def loss(self, logit_dict, label_dict, loss_func_dict=None):
+        if loss_func_dict is None:
+            if self.loss_func_dict is None:
+                self.loss_func_dict = defaultdict(list)
+                self.loss_func_dict['semantic'].append(paddleseg.models.MSELoss())
+                self.loss_func_dict['detail'].append(paddleseg.models.L1Loss())
+                self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
+                self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
+        else:
+            self.loss_func_dict = loss_func_dict
+
+        loss = {}
+        # semantic loss
+        semantic_gt = F.interpolate(
+            label_dict['alpha'],
+            scale_factor=1 / 16,
+            mode='bilinear',
+            align_corners=False)
+        semantic_gt = self.blurer(semantic_gt)
+        # semantic_gt.stop_gradient=True
+        loss['semantic'] = self.loss_func_dict['semantic'][0](
+            logit_dict['semantic'], semantic_gt)
+
+        # detail loss
+        trimap = label_dict['trimap']
+        mask = (trimap == 128).astype('float32')
+        logit_detail = logit_dict['detail'] * mask
+        label_detail = label_dict['alpha'] * mask
+        loss_detail = self.loss_func_dict['detail'][0](logit_detail,
+                                                       label_detail)
+        loss_detail = loss_detail / (mask.mean() + 1e-6)
+        loss['detail'] = 10 * loss_detail
+
+        # fusion loss
+        matte = logit_dict['matte']
+        alpha = label_dict['alpha']
+        transition_mask = label_dict['trimap'] == 128
+        matte_boundary = paddle.where(transition_mask, matte, alpha)
+        # l1 loss
+        loss_fusion_l1 = self.loss_func_dict['fusion'][0](
+            matte, alpha) + 4 * self.loss_func_dict['fusion'][0](matte_boundary,
+                                                                 alpha)
+        # composition loss
+        loss_fusion_comp = self.loss_func_dict['fusion'][1](
+            matte * label_dict['img'], alpha *
+            label_dict['img']) + 4 * self.loss_func_dict['fusion'][1](
+                matte_boundary * label_dict['img'], alpha * label_dict['img'])
+        # consistency loss with the semantic branch
+        transition_mask = F.interpolate(
+            label_dict['trimap'],
+            scale_factor=1 / 16,
+            mode='nearest',
+            align_corners=False)
+        transition_mask = transition_mask == 128
+        matte_con_sem = F.interpolate(
+            matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
+        matte_con_sem = self.blurer(matte_con_sem)
+        logit_semantic = logit_dict['semantic'].clone()
+        logit_semantic.stop_gradient = True
+        matte_con_sem = paddle.where(transition_mask, logit_semantic,
+                                     matte_con_sem)
+        mse_loss = paddleseg.models.MSELoss()
+        loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
+        loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
+        loss['fusion'] = loss_fusion
+        loss['fusion_l1'] = loss_fusion_l1
+        loss['fusion_comp'] = loss_fusion_comp
+        loss['fusion_con_sem'] = loss_fusion_con_sem
+
+        loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']
+
+        return loss
+
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_entire_model(self, self.pretrained)
+
+
+class MODNetHead(nn.Layer):
+    def __init__(self, hr_channels, backbone_channels):
+        super().__init__()
+
+        self.lr_branch = LRBranch(backbone_channels)
+        self.hr_branch = HRBranch(hr_channels, backbone_channels)
+        self.f_branch = FusionBranch(hr_channels, backbone_channels)
+        self.init_weight()
+
+    def forward(self, inputs, feat_list):
+        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(feat_list)
+        pred_detail, hr2x = self.hr_branch(inputs['img'], enc2x, enc4x, lr8x)
+        pred_matte = self.f_branch(inputs['img'], lr8x, hr2x)
+
+        if self.training:
+            logit_dict = {
+                'semantic': pred_semantic,
+                'detail': pred_detail,
+                'matte': pred_matte
+            }
+            return logit_dict
+        else:
+            return pred_matte
+
+    def init_weight(self):
+        for layer in self.sublayers():
+            if isinstance(layer, nn.Conv2D):
+                param_init.kaiming_uniform(layer.weight)
+
+
+class FusionBranch(nn.Layer):
+    def __init__(self, hr_channels, enc_channels):
+        super().__init__()
+        self.conv_lr4x = Conv2dIBNormRelu(
+            enc_channels[2], hr_channels, 5, stride=1, padding=2)
+
+        self.conv_f2x = Conv2dIBNormRelu(
+            2 * hr_channels, hr_channels, 3, stride=1, padding=1)
+        self.conv_f = nn.Sequential(
+            Conv2dIBNormRelu(
+                hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
+            Conv2dIBNormRelu(
+                int(hr_channels / 2),
+                1,
+                1,
+                stride=1,
+                padding=0,
+                with_ibn=False,
+                with_relu=False))
+
+    def forward(self, img, lr8x, hr2x):
+        lr4x = F.interpolate(
+            lr8x, scale_factor=2, mode='bilinear', align_corners=False)
+        lr4x = self.conv_lr4x(lr4x)
+        lr2x = F.interpolate(
+            lr4x, scale_factor=2, mode='bilinear', align_corners=False)
+
+        f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
+        f = F.interpolate(
+            f2x, scale_factor=2, mode='bilinear', align_corners=False)
+        f = self.conv_f(paddle.concat((f, img), axis=1))
+        pred_matte = F.sigmoid(f)
+
+        return pred_matte
+
+
+class HRBranch(nn.Layer):
+    """
+    High Resolution Branch of MODNet
+    """
+
+    def __init__(self, hr_channels, enc_channels):
+        super().__init__()
+
+        self.tohr_enc2x = Conv2dIBNormRelu(
+            enc_channels[0], hr_channels, 1, stride=1, padding=0)
+        self.conv_enc2x = Conv2dIBNormRelu(
+            hr_channels + 3, hr_channels, 3, stride=2, padding=1)
+
+        self.tohr_enc4x = Conv2dIBNormRelu(
+            enc_channels[1], hr_channels, 1, stride=1, padding=0)
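+        # Each tohr_* conv projects an encoder feature down to hr_channels so
+        # it can be fused with the downsampled image and the low-resolution
+        # context below.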
self.conv_enc4x = Conv2dIBNormRelu( + 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu( + 2 * hr_channels + enc_channels[2] + 3, + 2 * hr_channels, + 3, + stride=1, + padding=1), + Conv2dIBNormRelu( + 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu( + 2 * hr_channels, hr_channels, 3, stride=1, padding=1)) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu( + 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu( + 2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu( + hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu( + hr_channels, hr_channels, 3, stride=1, padding=1)) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu( + hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu( + hr_channels, + 1, + 1, + stride=1, + padding=0, + with_ibn=False, + with_relu=False)) + + def forward(self, img, enc2x, enc4x, lr8x): + img2x = F.interpolate( + img, scale_factor=1 / 2, mode='bilinear', align_corners=False) + img4x = F.interpolate( + img, scale_factor=1 / 4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1)) + + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1)) + + lr4x = F.interpolate( + lr8x, scale_factor=2, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1)) + + hr2x = F.interpolate( + hr4x, scale_factor=2, mode='bilinear', align_corners=False) + hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1)) + + pred_detail = None + if self.training: + hr = F.interpolate( + hr2x, scale_factor=2, mode='bilinear', align_corners=False) + hr = self.conv_hr(paddle.concat((hr, img), axis=1)) + pred_detail = F.sigmoid(hr) + + return pred_detail, hr2x + + +class LRBranch(nn.Layer): + def __init__(self, backbone_channels): + super().__init__() + self.se_block = SEBlock(backbone_channels[4], reduction=4) + self.conv_lr16x = Conv2dIBNormRelu( + backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu( + backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu( + backbone_channels[2], + 1, + 3, + stride=2, + padding=1, + with_ibn=False, + with_relu=False) + + def forward(self, feat_list): + enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4] + + enc32x = self.se_block(enc32x) + lr16x = F.interpolate( + enc32x, scale_factor=2, mode='bilinear', align_corners=False) + lr16x = self.conv_lr16x(lr16x) + lr8x = F.interpolate( + lr16x, scale_factor=2, mode='bilinear', align_corners=False) + lr8x = self.conv_lr8x(lr8x) + + pred_semantic = None + if self.training: + lr = self.conv_lr(lr8x) + pred_semantic = F.sigmoid(lr) + + return pred_semantic, lr8x, [enc2x, enc4x] + + +class IBNorm(nn.Layer): + """ + Combine Instance Norm and Batch Norm into One Layer + """ + + def __init__(self, in_channels): + super().__init__() + self.bnorm_channels = in_channels // 2 + self.inorm_channels = in_channels - self.bnorm_channels + + self.bnorm = nn.BatchNorm2D(self.bnorm_channels) + self.inorm = nn.InstanceNorm2D(self.inorm_channels) + + def forward(self, x): + bn_x = self.bnorm(x[:, :self.bnorm_channels, :, :]) + in_x = self.inorm(x[:, self.bnorm_channels:, :, :]) + + return paddle.concat((bn_x, in_x), 1) + + +class Conv2dIBNormRelu(nn.Layer): + """ + Convolution + 
IBNorm + Relu + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias_attr=None, + with_ibn=True, + with_relu=True): + + super().__init__() + + layers = [ + nn.Conv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias_attr) + ] + + if with_ibn: + layers.append(IBNorm(out_channels)) + + if with_relu: + layers.append(nn.ReLU()) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class SEBlock(nn.Layer): + """ + SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ + + def __init__(self, num_channels, reduction=1): + super().__init__() + self.pool = nn.AdaptiveAvgPool2D(1) + self.conv = nn.Sequential( + nn.Conv2D( + num_channels, + int(num_channels // reduction), + 1, + bias_attr=False), + nn.ReLU(), + nn.Conv2D( + int(num_channels // reduction), + num_channels, + 1, + bias_attr=False), + nn.Sigmoid()) + + def forward(self, x): + w = self.pool(x) + w = self.conv(w) + return w * x + + +class GaussianBlurLayer(nn.Layer): + """ Add Gaussian Blur to a 4D tensors + This layer takes a 4D tensor of {N, C, H, W} as input. + The Gaussian blur will be performed in given channel number (C) splitly. + """ + + def __init__(self, channels, kernel_size): + """ + Args: + channels (int): Channel for input tensor + kernel_size (int): Size of the kernel used in blurring + """ + + super(GaussianBlurLayer, self).__init__() + self.channels = channels + self.kernel_size = kernel_size + assert self.kernel_size % 2 != 0 + + self.op = nn.Sequential( + nn.Pad2D( + int(self.kernel_size / 2), mode='reflect'), + nn.Conv2D( + channels, + channels, + self.kernel_size, + stride=1, + padding=0, + bias_attr=False, + groups=channels)) + + self._init_kernel() + self.op[1].weight.stop_gradient = True + + def forward(self, x): + """ + Args: + x (paddle.Tensor): input 4D tensor + Returns: + paddle.Tensor: Blurred version of the input + """ + + if not len(list(x.shape)) == 4: + print('\'GaussianBlurLayer\' requires a 4D tensor as input\n') + exit() + elif not x.shape[1] == self.channels: + print('In \'GaussianBlurLayer\', the required channel ({0}) is' + 'not the same as input ({1})\n'.format(self.channels, x.shape[ + 1])) + exit() + + return self.op(x) + + def _init_kernel(self): + sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8 + + n = np.zeros((self.kernel_size, self.kernel_size)) + i = int(self.kernel_size / 2) + n[i, i] = 1 + kernel = scipy.ndimage.gaussian_filter(n, sigma) + kernel = kernel.astype('float32') + kernel = kernel[np.newaxis, np.newaxis, :, :] + paddle.assign(kernel, self.op[1].weight) diff --git a/ppmatting/models/ppmatting.py b/ppmatting/models/ppmatting.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed14528b5e598eda3a8fd6030a51ecc81dc6e3c --- /dev/null +++ b/ppmatting/models/ppmatting.py @@ -0,0 +1,338 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import time
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import paddleseg
+from paddleseg.models import layers
+from paddleseg import utils
+from paddleseg.cvlibs import manager
+
+from ppmatting.models.losses import MRSD, GradientLoss
+from ppmatting.models.backbone import resnet_vd
+
+
+@manager.MODELS.add_component
+class PPMatting(nn.Layer):
+    """
+    The PPMatting implementation based on PaddlePaddle.
+
+    The original article refers to
+    Guowei Chen, et al. "PP-Matting: High-Accuracy Natural Image Matting"
+    (https://arxiv.org/pdf/2204.09433.pdf).
+
+    Args:
+        backbone: backbone model.
+        pretrained(str, optional): The path of the pretrained model. Default: None.
+
+    """
+
+    def __init__(self, backbone, pretrained=None):
+        super().__init__()
+        self.backbone = backbone
+        self.pretrained = pretrained
+        self.loss_func_dict = self.get_loss_func_dict()
+
+        self.backbone_channels = backbone.feat_channels
+
+        self.scb = SCB(self.backbone_channels[-1])
+
+        self.hrdb = HRDB(
+            self.backbone_channels[0] + self.backbone_channels[1],
+            scb_channels=self.scb.out_channels,
+            gf_index=[0, 2, 4])
+
+        self.init_weight()
+
+    def forward(self, inputs):
+        x = inputs['img']
+        input_shape = paddle.shape(x)
+        fea_list = self.backbone(x)
+
+        scb_logits = self.scb(fea_list[-1])
+        semantic_map = F.softmax(scb_logits[-1], axis=1)
+
+        fea0 = F.interpolate(
+            fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
+        fea1 = F.interpolate(
+            fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
+        hrdb_input = paddle.concat([fea0, fea1], 1)
+        hrdb_logit = self.hrdb(hrdb_input, scb_logits)
+        detail_map = F.sigmoid(hrdb_logit)
+        fusion = self.fusion(semantic_map, detail_map)
+
+        if self.training:
+            logit_dict = {
+                'semantic': semantic_map,
+                'detail': detail_map,
+                'fusion': fusion
+            }
+            loss_dict = self.loss(logit_dict, inputs)
+            return logit_dict, loss_dict
+        else:
+            return fusion
+
+    def get_loss_func_dict(self):
+        loss_func_dict = defaultdict(list)
+        loss_func_dict['semantic'].append(nn.NLLLoss())
+        loss_func_dict['detail'].append(MRSD())
+        loss_func_dict['detail'].append(GradientLoss())
+        loss_func_dict['fusion'].append(MRSD())
+        loss_func_dict['fusion'].append(MRSD())
+        loss_func_dict['fusion'].append(GradientLoss())
+        return loss_func_dict
+
+    def loss(self, logit_dict, label_dict):
+        loss = {}
+
+        # semantic loss computation
+        # get semantic label
+        semantic_label = label_dict['trimap']
+        semantic_label_trans = (semantic_label == 128).astype('int64')
+        semantic_label_bg = (semantic_label == 0).astype('int64')
+        semantic_label = semantic_label_trans + semantic_label_bg * 2
+        loss_semantic = self.loss_func_dict['semantic'][0](
+            paddle.log(logit_dict['semantic'] + 1e-6),
+            semantic_label.squeeze(1))
+        loss['semantic'] = loss_semantic
+
+        # detail loss computation
+        transparent = label_dict['trimap'] == 128
+        detail_alpha_loss = self.loss_func_dict['detail'][0](
+            logit_dict['detail'], label_dict['alpha'], transparent)
+        # gradient loss
+        detail_gradient_loss = self.loss_func_dict['detail'][1](
+            logit_dict['detail'], label_dict['alpha'], transparent)
+        loss_detail = detail_alpha_loss + detail_gradient_loss
+        loss['detail'] = loss_detail
+        loss['detail_alpha'] = detail_alpha_loss
+        loss['detail_gradient'] = detail_gradient_loss
+
+        # fusion loss
+        loss_fusion_func = self.loss_func_dict['fusion']
+        # fusion alpha loss
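+        # The fusion head is trained with three terms: an alpha loss on the
+        # fused matte, a composition loss, and a gradient loss. The
+        # composition term compares images rebuilt with the compositing
+        # identity I = alpha * F + (1 - alpha) * B, once with the predicted
+        # alpha and once with the ground truth:
+        #
+        #     comp_pred = alpha_pred * fg + (1 - alpha_pred) * bg
+        #     comp_gt = alpha_gt * fg + (1 - alpha_gt) * bg
+        #
+        # where fg and bg are label_dict['fg'] and label_dict['bg'].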
+        fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
+                                                label_dict['alpha'])
+        # composition loss
+        comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
+            1 - logit_dict['fusion']) * label_dict['bg']
+        comp_gt = label_dict['alpha'] * label_dict['fg'] + (
+            1 - label_dict['alpha']) * label_dict['bg']
+        fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)
+        # gradient loss
+        fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
+                                               label_dict['alpha'])
+        # fusion loss
+        loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
+        loss['fusion'] = loss_fusion
+        loss['fusion_alpha'] = fusion_alpha_loss
+        loss['fusion_composition'] = fusion_composition_loss
+        loss['fusion_gradient'] = fusion_grad_loss
+
+        loss['all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion
+
+        return loss
+
+    def fusion(self, semantic_map, detail_map):
+        # semantic_map [N, 3, H, W]
+        # In index, 0 is foreground, 1 is transition, 2 is background.
+        # After fusion, the foreground is 1, the background is 0, and the
+        # transition is between [0, 1].
+        index = paddle.argmax(semantic_map, axis=1, keepdim=True)
+        transition_mask = (index == 1).astype('float32')
+        fg = (index == 0).astype('float32')
+        alpha = detail_map * transition_mask + fg
+        return alpha
+
+    def init_weight(self):
+        if self.pretrained is not None:
+            utils.load_entire_model(self, self.pretrained)
+
+
+class SCB(nn.Layer):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
+        self.mid_channels = [512, 256, 128, 128, 64, 64]
+        self.out_channels = [256, 128, 64, 64, 64, 3]
+
+        self.psp_module = layers.PPModule(
+            in_channels,
+            512,
+            bin_sizes=(1, 3, 5),
+            dim_reduction=False,
+            align_corners=False)
+
+        psp_upsamples = [2, 4, 8, 16]
+        self.psps = nn.LayerList([
+            self.conv_up_psp(512, self.out_channels[i], psp_upsamples[i])
+            for i in range(4)
+        ])
+
+        scb_list = [
+            self._make_stage(
+                self.in_channels[i],
+                self.mid_channels[i],
+                self.out_channels[i],
+                padding=int(i == 0) + 1,
+                dilation=int(i == 0) + 1)
+            for i in range(len(self.in_channels) - 1)
+        ]
+        scb_list += [
+            nn.Sequential(
+                layers.ConvBNReLU(
+                    self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
+                layers.ConvBNReLU(
+                    self.mid_channels[-1], self.mid_channels[-1], 3, padding=1),
+                nn.Conv2D(
+                    self.mid_channels[-1], self.out_channels[-1], 3, padding=1))
+        ]
+        self.scb_stages = nn.LayerList(scb_list)
+
+    def forward(self, x):
+        psp_x = self.psp_module(x)
+        psps = [psp(psp_x) for psp in self.psps]
+
+        scb_logits = []
+        for i, scb_stage in enumerate(self.scb_stages):
+            if i == 0:
+                x = scb_stage(paddle.concat((psp_x, x), 1))
+            elif i <= len(psps):
+                x = scb_stage(paddle.concat((psps[i - 1], x), 1))
+            else:
+                x = scb_stage(x)
+            scb_logits.append(x)
+        return scb_logits
+
+    def conv_up_psp(self, in_channels, out_channels, up_sample):
+        return nn.Sequential(
+            layers.ConvBNReLU(
+                in_channels, out_channels, 3, padding=1),
+            nn.Upsample(
+                scale_factor=up_sample, mode='bilinear', align_corners=False))
+
+    def _make_stage(self,
+                    in_channels,
+                    mid_channels,
+                    out_channels,
+                    padding=1,
+                    dilation=1):
+        layer_list = [
+            layers.ConvBNReLU(
+                in_channels, mid_channels, 3, padding=1),
+            layers.ConvBNReLU(
+                mid_channels,
+                mid_channels,
+                3,
+                padding=padding,
+                dilation=dilation),
+            layers.ConvBNReLU(
+                mid_channels,
+                out_channels,
+                3,
+                padding=padding,
+                dilation=dilation),
+            nn.Upsample(
+                scale_factor=2,
+                mode='bilinear',
+                align_corners=False)
+        ]
+        return
nn.Sequential(*layer_list) + + +class HRDB(nn.Layer): + """ + The High-Resolution Detail Branch + + Args: + in_channels(int): The number of input channels. + scb_channels(list|tuple): The channels of scb logits + gf_index(list|tuple, optional): Which logit is selected as guidance flow from scb logits. Default: (0, 2, 4) + """ + + def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)): + super().__init__() + self.gf_index = gf_index + self.gf_list = nn.LayerList( + [nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index]) + + channels = [64, 32, 16, 8] + self.res_list = [ + resnet_vd.BasicBlock( + in_channels, channels[0], stride=1, shortcut=False) + ] + self.res_list += [ + resnet_vd.BasicBlock( + i, i, stride=1) for i in channels[1:-1] + ] + self.res_list = nn.LayerList(self.res_list) + + self.convs = nn.LayerList([ + nn.Conv2D( + channels[i], channels[i + 1], kernel_size=1) + for i in range(len(channels) - 1) + ]) + self.gates = nn.LayerList( + [GatedSpatailConv2d(i, i) for i in channels[1:]]) + + self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False) + + def forward(self, x, scb_logits): + for i in range(len(self.res_list)): + x = self.res_list[i](x) + x = self.convs[i](x) + gf = self.gf_list[i](scb_logits[self.gf_index[i]]) + gf = F.interpolate( + gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False) + x = self.gates[i](x, gf) + return self.detail_conv(x) + + +class GatedSpatailConv2d(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias_attr=False): + super().__init__() + self._gate_conv = nn.Sequential( + layers.SyncBatchNorm(in_channels + 1), + nn.Conv2D( + in_channels + 1, in_channels + 1, kernel_size=1), + nn.ReLU(), + nn.Conv2D( + in_channels + 1, 1, kernel_size=1), + layers.SyncBatchNorm(1), + nn.Sigmoid()) + self.conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias_attr) + + def forward(self, input_features, gating_features): + cat = paddle.concat([input_features, gating_features], axis=1) + alphas = self._gate_conv(cat) + x = input_features * (alphas + 1) + x = self.conv(x) + return x diff --git a/ppmatting/transforms/__init__.py b/ppmatting/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78a33bf7dbe3a673c59a176704340c8baed8ba00 --- /dev/null +++ b/ppmatting/transforms/__init__.py @@ -0,0 +1 @@ +from .transforms import * diff --git a/ppmatting/transforms/transforms.py b/ppmatting/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..afd28b4917a890890820e56785b81c841b2d387a --- /dev/null +++ b/ppmatting/transforms/transforms.py @@ -0,0 +1,791 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
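+
+# All transforms in this module operate on a dict. By the convention used
+# below, the dict carries:
+#   'img'        - the image as an HWC numpy array (or a path before LoadImages)
+#   'gt_fields'  - names of ground-truth keys ('alpha', 'trimap', 'fg', 'bg')
+#                  that should be transformed together with the image
+#   'trans_info' - a list of (op_name, shape) records used to undo resizing
+#
+# A minimal sketch of how a pipeline might be assembled and applied
+# ('demo.jpg' is a hypothetical path):
+#
+#     transforms = Compose([LoadImages(), ResizeByShort(512),
+#                           ResizeToIntMult(32), Normalize()])
+#     data = transforms({'img': 'demo.jpg', 'gt_fields': []})
+#     chw_img = data['img']  # Compose transposes the image to CHW at the end
+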
+
+import os
+import random
+import string
+
+import cv2
+import numpy as np
+from paddleseg.transforms import functional
+from paddleseg.cvlibs import manager
+from paddleseg.utils import seg_env
+from PIL import Image
+
+
+@manager.TRANSFORMS.add_component
+class Compose:
+    """
+    Do transformation on input data with corresponding pre-processing and augmentation operations.
+    The shape of input data to all operations is [height, width, channels].
+    """
+
+    def __init__(self, transforms, to_rgb=True):
+        if not isinstance(transforms, list):
+            raise TypeError('The transforms must be a list!')
+        self.transforms = transforms
+        self.to_rgb = to_rgb
+
+    def __call__(self, data):
+        """
+        Args:
+            data (dict): The data to transform.
+
+        Returns:
+            dict: Data after transformation
+        """
+        if 'trans_info' not in data:
+            data['trans_info'] = []
+        for op in self.transforms:
+            data = op(data)
+            if data is None:
+                return None
+
+        data['img'] = np.transpose(data['img'], (2, 0, 1))
+        for key in data.get('gt_fields', []):
+            if len(data[key].shape) == 2:
+                continue
+            data[key] = np.transpose(data[key], (2, 0, 1))
+
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class LoadImages:
+    def __init__(self, to_rgb=False):
+        self.to_rgb = to_rgb
+
+    def __call__(self, data):
+        if isinstance(data['img'], str):
+            data['img'] = cv2.imread(data['img'])
+        for key in data.get('gt_fields', []):
+            if isinstance(data[key], str):
+                data[key] = cv2.imread(data[key], cv2.IMREAD_UNCHANGED)
+            # if alpha or trimap has 3 channels, extract one.
+            if key in ['alpha', 'trimap']:
+                if len(data[key].shape) > 2:
+                    data[key] = data[key][:, :, 0]
+
+        if self.to_rgb:
+            data['img'] = cv2.cvtColor(data['img'], cv2.COLOR_BGR2RGB)
+            for key in data.get('gt_fields', []):
+                if len(data[key].shape) == 2:
+                    continue
+                data[key] = cv2.cvtColor(data[key], cv2.COLOR_BGR2RGB)
+
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class Resize:
+    def __init__(self, target_size=(512, 512), random_interp=False):
+        if isinstance(target_size, list) or isinstance(target_size, tuple):
+            if len(target_size) != 2:
+                raise ValueError(
+                    '`target_size` should include 2 elements, but it is {}'.
+                    format(target_size))
+        else:
+            raise TypeError(
+                "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
+                .format(type(target_size)))
+
+        self.target_size = target_size
+        self.random_interp = random_interp
+        self.interps = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC]
+
+    def __call__(self, data):
+        if self.random_interp:
+            interp = np.random.choice(self.interps)
+        else:
+            interp = cv2.INTER_LINEAR
+        data['trans_info'].append(('resize', data['img'].shape[0:2]))
+        data['img'] = functional.resize(data['img'], self.target_size, interp)
+        for key in data.get('gt_fields', []):
+            if key == 'trimap':
+                data[key] = functional.resize(data[key], self.target_size,
+                                              cv2.INTER_NEAREST)
+            else:
+                data[key] = functional.resize(data[key], self.target_size,
+                                              interp)
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class RandomResize:
+    """
+    Resize image to a size determined by `scale` and `size`.
+
+    Args:
+        size(tuple|list): The reference size to resize. A tuple or list with length 2.
+        scale(tuple|list, optional): A range of scale based on `size`. A tuple or list with length 2. Default: None.
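+
+    The final factor is the product of a random value drawn from `scale` and
+    a base factor max(size[0] / w, size[1] / h), i.e. the image is first
+    scaled so that it covers the reference `size` and is then jittered by
+    the random factor.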
+    """
+
+    def __init__(self, size=None, scale=None):
+        if isinstance(size, list) or isinstance(size, tuple):
+            if len(size) != 2:
+                raise ValueError(
+                    '`size` should include 2 elements, but it is {}'.format(
+                        size))
+        elif size is not None:
+            raise TypeError(
+                "Type of `size` is invalid. It should be list or tuple, but it is {}"
+                .format(type(size)))
+
+        if scale is not None:
+            if isinstance(scale, list) or isinstance(scale, tuple):
+                if len(scale) != 2:
+                    raise ValueError(
+                        '`scale` should include 2 elements, but it is {}'.
+                        format(scale))
+            else:
+                raise TypeError(
+                    "Type of `scale` is invalid. It should be list or tuple, but it is {}"
+                    .format(type(scale)))
+        self.size = size
+        self.scale = scale
+
+    def __call__(self, data):
+        h, w = data['img'].shape[:2]
+        if self.scale is not None:
+            scale = np.random.uniform(self.scale[0], self.scale[1])
+        else:
+            scale = 1.
+        if self.size is not None:
+            scale_factor = max(self.size[0] / w, self.size[1] / h)
+        else:
+            scale_factor = 1
+        scale = scale * scale_factor
+
+        w = int(round(w * scale))
+        h = int(round(h * scale))
+        data['img'] = functional.resize(data['img'], (w, h))
+        for key in data.get('gt_fields', []):
+            if key == 'trimap':
+                data[key] = functional.resize(data[key], (w, h),
+                                              cv2.INTER_NEAREST)
+            else:
+                data[key] = functional.resize(data[key], (w, h))
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class ResizeByLong:
+    """
+    Resize the long side of an image to a given size, and then scale the other side proportionally.
+
+    Args:
+        long_size (int): The target size of the long side.
+    """
+
+    def __init__(self, long_size):
+        self.long_size = long_size
+
+    def __call__(self, data):
+        data['trans_info'].append(('resize', data['img'].shape[0:2]))
+        data['img'] = functional.resize_long(data['img'], self.long_size)
+        for key in data.get('gt_fields', []):
+            if key == 'trimap':
+                data[key] = functional.resize_long(data[key], self.long_size,
+                                                   cv2.INTER_NEAREST)
+            else:
+                data[key] = functional.resize_long(data[key], self.long_size)
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class ResizeByShort:
+    """
+    Resize the short side of an image to a given size, and then scale the other side proportionally.
+
+    Args:
+        short_size (int): The target size of the short side.
+    """
+
+    def __init__(self, short_size):
+        self.short_size = short_size
+
+    def __call__(self, data):
+        data['trans_info'].append(('resize', data['img'].shape[0:2]))
+        data['img'] = functional.resize_short(data['img'], self.short_size)
+        for key in data.get('gt_fields', []):
+            if key == 'trimap':
+                data[key] = functional.resize_short(data[key], self.short_size,
+                                                    cv2.INTER_NEAREST)
+            else:
+                data[key] = functional.resize_short(data[key], self.short_size)
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class ResizeToIntMult:
+    """
+    Resize an image so that both sides are an integer multiple of `mult_int`, e.g. 32.
+    """
+
+    def __init__(self, mult_int=32):
+        self.mult_int = mult_int
+
+    def __call__(self, data):
+        data['trans_info'].append(('resize', data['img'].shape[0:2]))
+
+        h, w = data['img'].shape[0:2]
+        rw = w - w % self.mult_int
+        rh = h - h % self.mult_int
+        data['img'] = functional.resize(data['img'], (rw, rh))
+        for key in data.get('gt_fields', []):
+            if key == 'trimap':
+                data[key] = functional.resize(data[key], (rw, rh),
+                                              cv2.INTER_NEAREST)
+            else:
+                data[key] = functional.resize(data[key], (rw, rh))
+
+        return data
+
+
+@manager.TRANSFORMS.add_component
+class Normalize:
+    """
+    Normalize an image.
+
+    Args:
+        mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
+ std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5]. + + Raises: + ValueError: When mean/std is not list or any value in std is 0. + """ + + def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): + self.mean = mean + self.std = std + if not (isinstance(self.mean, + (list, tuple)) and isinstance(self.std, + (list, tuple))): + raise ValueError( + "{}: input type is invalid. It should be list or tuple".format( + self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def __call__(self, data): + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + data['img'] = functional.normalize(data['img'], mean, std) + if 'fg' in data.get('gt_fields', []): + data['fg'] = functional.normalize(data['fg'], mean, std) + if 'bg' in data.get('gt_fields', []): + data['bg'] = functional.normalize(data['bg'], mean, std) + + return data + + +@manager.TRANSFORMS.add_component +class RandomCropByAlpha: + """ + Randomly crop while centered on uncertain area by a certain probability. + + Args: + crop_size (tuple|list): The size you want to crop from image. + p (float): The probability centered on uncertain area. + + """ + + def __init__(self, crop_size=((320, 320), (480, 480), (640, 640)), + prob=0.5): + self.crop_size = crop_size + self.prob = prob + + def __call__(self, data): + idex = np.random.randint(low=0, high=len(self.crop_size)) + crop_w, crop_h = self.crop_size[idex] + + img_h = data['img'].shape[0] + img_w = data['img'].shape[1] + if np.random.rand() < self.prob: + crop_center = np.where((data['alpha'] > 0) & (data['alpha'] < 255)) + center_h_array, center_w_array = crop_center + if len(center_h_array) == 0: + return data + rand_ind = np.random.randint(len(center_h_array)) + center_h = center_h_array[rand_ind] + center_w = center_w_array[rand_ind] + delta_h = crop_h // 2 + delta_w = crop_w // 2 + start_h = max(0, center_h - delta_h) + start_w = max(0, center_w - delta_w) + else: + start_h = 0 + start_w = 0 + if img_h > crop_h: + start_h = np.random.randint(img_h - crop_h + 1) + if img_w > crop_w: + start_w = np.random.randint(img_w - crop_w + 1) + + end_h = min(img_h, start_h + crop_h) + end_w = min(img_w, start_w + crop_w) + + data['img'] = data['img'][start_h:end_h, start_w:end_w] + for key in data.get('gt_fields', []): + data[key] = data[key][start_h:end_h, start_w:end_w] + + return data + + +@manager.TRANSFORMS.add_component +class RandomCrop: + """ + Randomly crop + + Args: + crop_size (tuple|list): The size you want to crop from image. + """ + + def __init__(self, crop_size=((320, 320), (480, 480), (640, 640))): + if not isinstance(crop_size[0], (list, tuple)): + crop_size = [crop_size] + self.crop_size = crop_size + + def __call__(self, data): + idex = np.random.randint(low=0, high=len(self.crop_size)) + crop_w, crop_h = self.crop_size[idex] + img_h, img_w = data['img'].shape[0:2] + + start_h = 0 + start_w = 0 + if img_h > crop_h: + start_h = np.random.randint(img_h - crop_h + 1) + if img_w > crop_w: + start_w = np.random.randint(img_w - crop_w + 1) + + end_h = min(img_h, start_h + crop_h) + end_w = min(img_w, start_w + crop_w) + + data['img'] = data['img'][start_h:end_h, start_w:end_w] + for key in data.get('gt_fields', []): + data[key] = data[key][start_h:end_h, start_w:end_w] + + return data + + +@manager.TRANSFORMS.add_component +class LimitLong: + """ + Limit the long edge of image. 
+ + If the long edge is larger than max_long, resize the long edge + to max_long, while scale the short edge proportionally. + + If the long edge is smaller than min_long, resize the long edge + to min_long, while scale the short edge proportionally. + + Args: + max_long (int, optional): If the long edge of image is larger than max_long, + it will be resize to max_long. Default: None. + min_long (int, optional): If the long edge of image is smaller than min_long, + it will be resize to min_long. Default: None. + """ + + def __init__(self, max_long=None, min_long=None): + if max_long is not None: + if not isinstance(max_long, int): + raise TypeError( + "Type of `max_long` is invalid. It should be int, but it is {}" + .format(type(max_long))) + if min_long is not None: + if not isinstance(min_long, int): + raise TypeError( + "Type of `min_long` is invalid. It should be int, but it is {}" + .format(type(min_long))) + if (max_long is not None) and (min_long is not None): + if min_long > max_long: + raise ValueError( + '`max_long should not smaller than min_long, but they are {} and {}' + .format(max_long, min_long)) + self.max_long = max_long + self.min_long = min_long + + def __call__(self, data): + h, w = data['img'].shape[:2] + long_edge = max(h, w) + target = long_edge + if (self.max_long is not None) and (long_edge > self.max_long): + target = self.max_long + elif (self.min_long is not None) and (long_edge < self.min_long): + target = self.min_long + + data['trans_info'].append(('resize', data['img'].shape[0:2])) + if target != long_edge: + data['img'] = functional.resize_long(data['img'], target) + for key in data.get('gt_fields', []): + if key == 'trimap': + data[key] = functional.resize_long(data[key], target, + cv2.INTER_NEAREST) + else: + data[key] = functional.resize_long(data[key], target) + + return data + + +@manager.TRANSFORMS.add_component +class LimitShort: + """ + Limit the short edge of image. + + If the short edge is larger than max_short, resize the short edge + to max_short, while scale the long edge proportionally. + + If the short edge is smaller than min_short, resize the short edge + to min_short, while scale the long edge proportionally. + + Args: + max_short (int, optional): If the short edge of image is larger than max_short, + it will be resize to max_short. Default: None. + min_short (int, optional): If the short edge of image is smaller than min_short, + it will be resize to min_short. Default: None. + """ + + def __init__(self, max_short=None, min_short=None): + if max_short is not None: + if not isinstance(max_short, int): + raise TypeError( + "Type of `max_short` is invalid. It should be int, but it is {}" + .format(type(max_short))) + if min_short is not None: + if not isinstance(min_short, int): + raise TypeError( + "Type of `min_short` is invalid. 
It should be int, but it is {}" + .format(type(min_short))) + if (max_short is not None) and (min_short is not None): + if min_short > max_short: + raise ValueError( + '`max_short should not smaller than min_short, but they are {} and {}' + .format(max_short, min_short)) + self.max_short = max_short + self.min_short = min_short + + def __call__(self, data): + h, w = data['img'].shape[:2] + short_edge = min(h, w) + target = short_edge + if (self.max_short is not None) and (short_edge > self.max_short): + target = self.max_short + elif (self.min_short is not None) and (short_edge < self.min_short): + target = self.min_short + + data['trans_info'].append(('resize', data['img'].shape[0:2])) + if target != short_edge: + data['img'] = functional.resize_short(data['img'], target) + for key in data.get('gt_fields', []): + if key == 'trimap': + data[key] = functional.resize_short(data[key], target, + cv2.INTER_NEAREST) + else: + data[key] = functional.resize_short(data[key], target) + + return data + + +@manager.TRANSFORMS.add_component +class RandomHorizontalFlip: + """ + Flip an image horizontally with a certain probability. + + Args: + prob (float, optional): A probability of horizontally flipping. Default: 0.5. + """ + + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, data): + if random.random() < self.prob: + data['img'] = functional.horizontal_flip(data['img']) + for key in data.get('gt_fields', []): + data[key] = functional.horizontal_flip(data[key]) + + return data + + +@manager.TRANSFORMS.add_component +class RandomBlur: + """ + Blurring an image by a Gaussian function with a certain probability. + + Args: + prob (float, optional): A probability of blurring an image. Default: 0.1. + """ + + def __init__(self, prob=0.1): + self.prob = prob + + def __call__(self, data): + if self.prob <= 0: + n = 0 + elif self.prob >= 1: + n = 1 + else: + n = int(1.0 / self.prob) + if n > 0: + if np.random.randint(0, n) == 0: + radius = np.random.randint(3, 10) + if radius % 2 != 1: + radius = radius + 1 + if radius > 9: + radius = 9 + data['img'] = cv2.GaussianBlur(data['img'], (radius, radius), 0, + 0) + for key in data.get('gt_fields', []): + if key == 'trimap': + continue + data[key] = cv2.GaussianBlur(data[key], (radius, radius), 0, + 0) + return data + + +@manager.TRANSFORMS.add_component +class RandomDistort: + """ + Distort an image with random configurations. + + Args: + brightness_range (float, optional): A range of brightness. Default: 0.5. + brightness_prob (float, optional): A probability of adjusting brightness. Default: 0.5. + contrast_range (float, optional): A range of contrast. Default: 0.5. + contrast_prob (float, optional): A probability of adjusting contrast. Default: 0.5. + saturation_range (float, optional): A range of saturation. Default: 0.5. + saturation_prob (float, optional): A probability of adjusting saturation. Default: 0.5. + hue_range (int, optional): A range of hue. Default: 18. + hue_prob (float, optional): A probability of adjusting hue. Default: 0.5. 
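+
+    For example, with brightness_range=0.5 the brightness factor is drawn
+    uniformly from [0.5, 1.5], and with hue_range=18 the hue shift is drawn
+    from [-18, 18]; each adjustment is applied independently with its own
+    probability.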
+ """ + + def __init__(self, + brightness_range=0.5, + brightness_prob=0.5, + contrast_range=0.5, + contrast_prob=0.5, + saturation_range=0.5, + saturation_prob=0.5, + hue_range=18, + hue_prob=0.5): + self.brightness_range = brightness_range + self.brightness_prob = brightness_prob + self.contrast_range = contrast_range + self.contrast_prob = contrast_prob + self.saturation_range = saturation_range + self.saturation_prob = saturation_prob + self.hue_range = hue_range + self.hue_prob = hue_prob + + def __call__(self, data): + brightness_lower = 1 - self.brightness_range + brightness_upper = 1 + self.brightness_range + contrast_lower = 1 - self.contrast_range + contrast_upper = 1 + self.contrast_range + saturation_lower = 1 - self.saturation_range + saturation_upper = 1 + self.saturation_range + hue_lower = -self.hue_range + hue_upper = self.hue_range + ops = [ + functional.brightness, functional.contrast, functional.saturation, + functional.hue + ] + random.shuffle(ops) + params_dict = { + 'brightness': { + 'brightness_lower': brightness_lower, + 'brightness_upper': brightness_upper + }, + 'contrast': { + 'contrast_lower': contrast_lower, + 'contrast_upper': contrast_upper + }, + 'saturation': { + 'saturation_lower': saturation_lower, + 'saturation_upper': saturation_upper + }, + 'hue': { + 'hue_lower': hue_lower, + 'hue_upper': hue_upper + } + } + prob_dict = { + 'brightness': self.brightness_prob, + 'contrast': self.contrast_prob, + 'saturation': self.saturation_prob, + 'hue': self.hue_prob + } + + im = data['img'].astype('uint8') + im = Image.fromarray(im) + for id in range(len(ops)): + params = params_dict[ops[id].__name__] + params['im'] = im + prob = prob_dict[ops[id].__name__] + if np.random.uniform(0, 1) < prob: + im = ops[id](**params) + data['img'] = np.asarray(im) + + for key in data.get('gt_fields', []): + if key in ['alpha', 'trimap']: + continue + else: + im = data[key].astype('uint8') + im = Image.fromarray(im) + for id in range(len(ops)): + params = params_dict[ops[id].__name__] + params['im'] = im + prob = prob_dict[ops[id].__name__] + if np.random.uniform(0, 1) < prob: + im = ops[id](**params) + data[key] = np.asarray(im) + return data + + +@manager.TRANSFORMS.add_component +class Padding: + """ + Add bottom-right padding to a raw image or annotation image. + + Args: + target_size (list|tuple): The target size after padding. + im_padding_value (list, optional): The padding value of raw image. + Default: [127.5, 127.5, 127.5]. + label_padding_value (int, optional): The padding value of annotation image. Default: 255. + + Raises: + TypeError: When target_size is neither list nor tuple. + ValueError: When the length of target_size is not 2. + """ + + def __init__(self, target_size, im_padding_value=(127.5, 127.5, 127.5)): + if isinstance(target_size, list) or isinstance(target_size, tuple): + if len(target_size) != 2: + raise ValueError( + '`target_size` should include 2 elements, but it is {}'. + format(target_size)) + else: + raise TypeError( + "Type of target_size is invalid. 
It should be list or tuple, now is {}" + .format(type(target_size))) + + self.target_size = target_size + self.im_padding_value = im_padding_value + + def __call__(self, data): + im_height, im_width = data['img'].shape[0], data['img'].shape[1] + target_height = self.target_size[1] + target_width = self.target_size[0] + pad_height = max(0, target_height - im_height) + pad_width = max(0, target_width - im_width) + data['trans_info'].append(('padding', data['img'].shape[0:2])) + if (pad_height == 0) and (pad_width == 0): + return data + else: + data['img'] = cv2.copyMakeBorder( + data['img'], + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) + for key in data.get('gt_fields', []): + if key in ['trimap', 'alpha']: + value = 0 + else: + value = self.im_padding_value + data[key] = cv2.copyMakeBorder( + data[key], + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=value) + return data + + +@manager.TRANSFORMS.add_component +class RandomSharpen: + def __init__(self, prob=0.1): + if prob < 0: + self.prob = 0 + elif prob > 1: + self.prob = 1 + else: + self.prob = prob + + def __call__(self, data): + if np.random.rand() > self.prob: + return data + + radius = np.random.choice([0, 3, 5, 7, 9]) + w = np.random.uniform(0.1, 0.5) + blur_img = cv2.GaussianBlur(data['img'], (radius, radius), 5) + data['img'] = cv2.addWeighted(data['img'], 1 + w, blur_img, -w, 0) + for key in data.get('gt_fields', []): + if key == 'trimap' or key == 'alpha': + continue + blur_img = cv2.GaussianBlur(data[key], (0, 0), 5) + data[key] = cv2.addWeighted(data[key], 1.5, blur_img, -0.5, 0) + + return data + + +@manager.TRANSFORMS.add_component +class RandomNoise: + def __init__(self, prob=0.1): + if prob < 0: + self.prob = 0 + elif prob > 1: + self.prob = 1 + else: + self.prob = prob + + def __call__(self, data): + if np.random.rand() > self.prob: + return data + mean = np.random.uniform(0, 0.04) + var = np.random.uniform(0, 0.001) + noise = np.random.normal(mean, var**0.5, data['img'].shape) * 255 + data['img'] = data['img'] + noise + data['img'] = np.clip(data['img'], 0, 255) + + return data + + +@manager.TRANSFORMS.add_component +class RandomReJpeg: + def __init__(self, prob=0.1): + if prob < 0: + self.prob = 0 + elif prob > 1: + self.prob = 1 + else: + self.prob = prob + + def __call__(self, data): + if np.random.rand() > self.prob: + return data + q = np.random.randint(70, 95) + img = data['img'].astype('uint8') + + # Ensure no conflicts between processes + tmp_name = str(os.getpid()) + '.jpg' + tmp_name = os.path.join(seg_env.TMP_HOME, tmp_name) + cv2.imwrite(tmp_name, img, [int(cv2.IMWRITE_JPEG_QUALITY), q]) + data['img'] = cv2.imread(tmp_name) + + return data diff --git a/ppmatting/utils/__init__.py b/ppmatting/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79717c71036b5b730cce8548bc27f6fef7222c21 --- /dev/null +++ b/ppmatting/utils/__init__.py @@ -0,0 +1,2 @@ +from .estimate_foreground_ml import estimate_foreground_ml +from .utils import get_files, get_image_list, mkdir diff --git a/ppmatting/utils/estimate_foreground_ml.py b/ppmatting/utils/estimate_foreground_ml.py new file mode 100644 index 0000000000000000000000000000000000000000..05bffb6c31a5042fd96c028013c81f7533f3675d --- /dev/null +++ b/ppmatting/utils/estimate_foreground_ml.py @@ -0,0 +1,236 @@ +import numpy as np +from numba import njit, prange + +# The foreground estimation refer to pymatting 
[https://github.com/pymatting/pymatting/blob/master/pymatting/foreground/estimate_foreground_ml.py]
+
+
+@njit("void(f4[:, :, :], f4[:, :, :])", cache=True, nogil=True, parallel=True)
+def _resize_nearest_multichannel(dst, src):
+    """
+    Internal method.
+
+    Resize image src to dst using nearest neighbors filtering.
+    Images must have multiple color channels, i.e. :code:`len(shape) == 3`.
+
+    Parameters
+    ----------
+    dst: numpy.ndarray of type np.float32
+        output image
+    src: numpy.ndarray of type np.float32
+        input image
+    """
+    h_src, w_src, depth = src.shape
+    h_dst, w_dst, depth = dst.shape
+
+    for y_dst in prange(h_dst):
+        for x_dst in range(w_dst):
+            x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
+            y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
+
+            for c in range(depth):
+                dst[y_dst, x_dst, c] = src[y_src, x_src, c]
+
+
+@njit("void(f4[:, :], f4[:, :])", cache=True, nogil=True, parallel=True)
+def _resize_nearest(dst, src):
+    """
+    Internal method.
+
+    Resize image src to dst using nearest neighbors filtering.
+    Images must be grayscale, i.e. :code:`len(shape) == 2`.
+
+    Parameters
+    ----------
+    dst: numpy.ndarray of type np.float32
+        output image
+    src: numpy.ndarray of type np.float32
+        input image
+    """
+    h_src, w_src = src.shape
+    h_dst, w_dst = dst.shape
+
+    for y_dst in prange(h_dst):
+        for x_dst in range(w_dst):
+            x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
+            y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
+
+            dst[y_dst, x_dst] = src[y_src, x_src]
+
+
+# TODO
+# There should be an option to switch @njit(parallel=True) on or off.
+# parallel=True would be faster, but might cause race conditions.
+# User should have the option to turn it on or off.
+@njit(
+    "Tuple((f4[:, :, :], f4[:, :, :]))(f4[:, :, :], f4[:, :], f4, i4, i4, i4, f4)",
+    cache=True,
+    nogil=True)
+def _estimate_fb_ml(
+        input_image,
+        input_alpha,
+        regularization,
+        n_small_iterations,
+        n_big_iterations,
+        small_size,
+        gradient_weight, ):
+    h0, w0, depth = input_image.shape
+
+    dtype = np.float32
+
+    w_prev = 1
+    h_prev = 1
+
+    F_prev = np.empty((h_prev, w_prev, depth), dtype=dtype)
+    B_prev = np.empty((h_prev, w_prev, depth), dtype=dtype)
+
+    n_levels = int(np.ceil(np.log2(max(w0, h0))))
+
+    for i_level in range(n_levels + 1):
+        w = round(w0**(i_level / n_levels))
+        h = round(h0**(i_level / n_levels))
+
+        image = np.empty((h, w, depth), dtype=dtype)
+        alpha = np.empty((h, w), dtype=dtype)
+
+        _resize_nearest_multichannel(image, input_image)
+        _resize_nearest(alpha, input_alpha)
+
+        F = np.empty((h, w, depth), dtype=dtype)
+        B = np.empty((h, w, depth), dtype=dtype)
+
+        _resize_nearest_multichannel(F, F_prev)
+        _resize_nearest_multichannel(B, B_prev)
+
+        if w <= small_size and h <= small_size:
+            n_iter = n_small_iterations
+        else:
+            n_iter = n_big_iterations
+
+        b = np.zeros((2, depth), dtype=dtype)
+
+        dx = [-1, 1, 0, 0]
+        dy = [0, 0, -1, 1]
+
+        for i_iter in range(n_iter):
+            for y in prange(h):
+                for x in range(w):
+                    a0 = alpha[y, x]
+                    a1 = 1.0 - a0
+
+                    a00 = a0 * a0
+                    a01 = a0 * a1
+                    # a10 = a01 can be omitted due to symmetry of matrix
+                    a11 = a1 * a1
+
+                    for c in range(depth):
+                        b[0, c] = a0 * image[y, x, c]
+                        b[1, c] = a1 * image[y, x, c]
+
+                    for d in range(4):
+                        x2 = max(0, min(w - 1, x + dx[d]))
+                        y2 = max(0, min(h - 1, y + dy[d]))
+
+                        gradient = abs(a0 - alpha[y2, x2])
+
+                        da = regularization + gradient_weight * gradient
+
+                        a00 += da
+                        a11 += da
+
+                        for c in range(depth):
+                            b[0, c] += da * F[y2, x2, c]
+                            b[1, c] += da * B[y2, x2, c]
+
+                    determinant = a00 * a11 - a01
* a01 + + inv_det = 1.0 / determinant + + b00 = inv_det * a11 + b01 = inv_det * -a01 + b11 = inv_det * a00 + + for c in range(depth): + F_c = b00 * b[0, c] + b01 * b[1, c] + B_c = b01 * b[0, c] + b11 * b[1, c] + + F_c = max(0.0, min(1.0, F_c)) + B_c = max(0.0, min(1.0, B_c)) + + F[y, x, c] = F_c + B[y, x, c] = B_c + + F_prev = F + B_prev = B + + w_prev = w + h_prev = h + + return F, B + + +def estimate_foreground_ml( + image, + alpha, + regularization=1e-5, + n_small_iterations=10, + n_big_iterations=2, + small_size=32, + return_background=False, + gradient_weight=1.0, ): + """Estimates the foreground of an image given its alpha matte. + + See :cite:`germer2020multilevel` for reference. + + Parameters + ---------- + image: numpy.ndarray + Input image with shape :math:`h \\times w \\times d` + alpha: numpy.ndarray + Input alpha matte shape :math:`h \\times w` + regularization: float + Regularization strength :math:`\\epsilon`, defaults to :math:`10^{-5}`. + Higher regularization results in smoother colors. + n_small_iterations: int + Number of iterations performed on small scale, defaults to :math:`10` + n_big_iterations: int + Number of iterations performed on large scale, defaults to :math:`2` + small_size: int + Threshold that determines at which size `n_small_iterations` should be used + return_background: bool + Whether to return the estimated background in addition to the foreground + gradient_weight: float + Larger values enforce smoother foregrounds, defaults to :math:`1` + + Returns + ------- + F: numpy.ndarray + Extracted foreground + B: numpy.ndarray + Extracted background + + Example + ------- + >>> from pymatting import * + >>> image = load_image("data/lemur/lemur.png", "RGB") + >>> alpha = load_image("data/lemur/lemur_alpha.png", "GRAY") + >>> F = estimate_foreground_ml(image, alpha, return_background=False) + >>> F, B = estimate_foreground_ml(image, alpha, return_background=True) + + See Also + ---- + stack_images: This function can be used to place the foreground on a new background. + """ + + foreground, background = _estimate_fb_ml( + image.astype(np.float32), + alpha.astype(np.float32), + regularization, + n_small_iterations, + n_big_iterations, + small_size, + gradient_weight, ) + + if return_background: + return foreground, background + + return foreground diff --git a/ppmatting/utils/utils.py b/ppmatting/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13513cb193757b63043f44a2c145b3e9b6fad82e --- /dev/null +++ b/ppmatting/utils/utils.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
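+
+# A small usage sketch for the helpers below (paths are hypothetical):
+#
+#     image_list, image_dir = get_image_list('data/images')   # a directory
+#     image_list, image_dir = get_image_list('data/val.txt')  # a list file
+#     mkdir('output/results/alpha.png')  # ensures output/results/ exists
+#
+# get_image_list accepts a single image, a directory, or a text file with
+# one relative image path per line, and returns the resolved image paths
+# together with the base directory (None for a single image file).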
+
+import os
+
+
+def get_files(root_path):
+    res = []
+    for root, dirs, files in os.walk(root_path, followlinks=True):
+        for f in files:
+            if f.endswith(('.jpg', '.png', '.jpeg', '.JPG')):
+                res.append(os.path.join(root, f))
+    return res
+
+
+def get_image_list(image_path):
+    """Get image list"""
+    valid_suffix = [
+        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
+    ]
+    image_list = []
+    image_dir = None
+    if os.path.isfile(image_path):
+        image_dir = None
+        if os.path.splitext(image_path)[-1] in valid_suffix:
+            image_list.append(image_path)
+        else:
+            image_dir = os.path.dirname(image_path)
+            with open(image_path, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if len(line.split()) > 1:
+                        raise RuntimeError(
+                            'There should be only one image path per line in `image_path` file. Wrong line: {}'
+                            .format(line))
+                    image_list.append(os.path.join(image_dir, line))
+    elif os.path.isdir(image_path):
+        image_dir = image_path
+        for root, dirs, files in os.walk(image_path):
+            for f in files:
+                if '.ipynb_checkpoints' in root:
+                    continue
+                if os.path.splitext(f)[-1] in valid_suffix:
+                    image_list.append(os.path.join(root, f))
+        image_list.sort()
+    else:
+        raise FileNotFoundError(
+            '`image_path` is not found. It should be an image file or a directory including images'
+        )
+
+    if len(image_list) == 0:
+        raise RuntimeError('There is no image file in `image_path`')
+
+    return image_list, image_dir
+
+
+def mkdir(path):
+    sub_dir = os.path.dirname(path)
+    if not os.path.exists(sub_dir):
+        os.makedirs(sub_dir)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..977b70ee41bf37405fb2d75d98cf4e33ccaee972
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+opencv-python>=4.6.0
+numpy>=1.23.5
+Pillow>=9.3.0
+paddlepaddle==2.4.0
+paddleseg>=2.6.0
+scikit-learn>=1.1.3
+pymatting>=1.1.8
+scikit-image>=0.19.3