update .gitignore
- .gitattributes +1 -0
- .gitignore +5 -0
- app.py +23 -0
- download.sh +11 -0
- examples/SOD001.jpg +0 -0
- examples/SOD003.jpeg +0 -0
- examples/SOD013.jpg +0 -0
- examples/SOD015.jpg +0 -0
- inpaint/configs/prediction/default.yaml +14 -0
- inpaint/infer_model.py +94 -0
- inpaint/predict.py +96 -0
- inpaint/saicinpainting/training/modules/__init__.py +7 -0
- inpaint/saicinpainting/training/modules/base.py +80 -0
- inpaint/saicinpainting/training/modules/depthwise_sep_conv.py +17 -0
- inpaint/saicinpainting/training/modules/fake_fakes.py +47 -0
- inpaint/saicinpainting/training/modules/ffc.py +367 -0
- inpaint/saicinpainting/training/modules/multidilated_conv.py +98 -0
- inpaint/saicinpainting/training/modules/multiscale.py +244 -0
- inpaint/saicinpainting/training/modules/pix2pixhd.py +669 -0
- inpaint/saicinpainting/training/modules/spatial_transform.py +49 -0
- inpaint/saicinpainting/training/modules/squeeze_excitation.py +20 -0
- inpaint/saicinpainting/training/trainers/__init__.py +26 -0
- inpaint/saicinpainting/training/trainers/base.py +19 -0
- inpaint/saicinpainting/training/trainers/default.py +53 -0
- sod/PGNet.py +270 -0
- sod/Res.py +363 -0
- sod/Swin.py +578 -0
- sod/configs/prediction/default.yaml +14 -0
- sod/infer_model.py +89 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,5 @@
inpaint/weights
sod/weights
**/__pycache__
flagged
**/*.zip
app.py
ADDED
@@ -0,0 +1,23 @@
import gradio as gr
import inpaint.infer_model as inpaint
import sod.infer_model as sod
import numpy as np
import torch
import os

# cmd = 'sh download.sh'
# os.system(cmd)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
inpaint_model = inpaint.IVModel(device=device)
sod_model = sod.IVModel(device=torch.device("cpu"))


def sod_inpaint(img):
    img = img[:, :, ::-1]  # Gradio delivers RGB; the models work in BGR
    res = sod_model.forward(img, None)      # returns image and predicted mask side by side
    res = np.uint8(res)
    res = inpaint_model.forward(res, None)  # erases the masked salient object
    res = np.uint8(res)
    return res[:, :, ::-1]  # back to RGB for display


iface = gr.Interface(fn=sod_inpaint, inputs="image", outputs="image", examples='examples',
                     title='显著物体消除',  # "Salient Object Removal"
                     description='这是一个图像API,功能是自动把画面中的显著物体消除',  # "An image API that automatically removes the salient object from the picture"
                     theme='huggingface')
iface.launch(server_name='0.0.0.0', share=False)
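For a quick local check, the pipeline can be exercised without launching Gradio. A minimal sketch, assuming the weights have already been downloaded and reusing one of the bundled examples:

```python
# Smoke test for the SOD -> inpaint pipeline (not part of this commit).
import cv2

img = cv2.imread('examples/SOD001.jpg')[:, :, ::-1]  # BGR -> RGB, as Gradio would pass it
out = sod_inpaint(img)                               # RGB result with the salient object removed
cv2.imwrite('result.png', out[:, :, ::-1])           # back to BGR for cv2.imwrite
```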
download.sh
ADDED
@@ -0,0 +1,11 @@
FILE_ID=1udSLeuWAZf2-uI7SI8dFEfBuyvLSpB9V
checkpoint_path='inpaint/weights.zip'
# The inner wget fetches Google Drive's confirm token; sed extracts it for the real download.
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILE_ID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILE_ID}" -O ${checkpoint_path} && rm -rf /tmp/cookies.txt
unzip $checkpoint_path

FILE_ID=1qI8-HBTz2nNSTyD9iB4YMn067P7XMKuD
checkpoint_path='sod/weights.zip'
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILE_ID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILE_ID}" -O ${checkpoint_path} && rm -rf /tmp/cookies.txt
unzip $checkpoint_path
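The wget/sed pipeline re-implements Google Drive's confirm-token handshake by hand; the gdown package does the same thing internally. A sketch of an equivalent download (an alternative, not part of this commit; assumes gdown is installed):

```python
# Equivalent of the first wget block above, using gdown instead (hypothetical alternative).
import gdown

gdown.download('https://drive.google.com/uc?id=1udSLeuWAZf2-uI7SI8dFEfBuyvLSpB9V',
               'inpaint/weights.zip', quiet=False)
```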
examples/SOD001.jpg
ADDED
examples/SOD003.jpeg
ADDED
examples/SOD013.jpg
ADDED
examples/SOD015.jpg
ADDED
inpaint/configs/prediction/default.yaml
ADDED
@@ -0,0 +1,14 @@
indir: no  # to be overridden in CLI
outdir: no  # to be overridden in CLI

model:
  path: no  # to be overridden in CLI
  checkpoint: best.ckpt

dataset:
  kind: default
  img_suffix: .png
  pad_out_to_modulo: 8

device: cuda
out_key: inpainted
inpaint/infer_model.py
ADDED
@@ -0,0 +1,94 @@
import os
import sys

sys.path.append('inpaint')
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint


class IVModel():
    def __init__(self, device=torch.device('cuda:0')):
        super(IVModel, self).__init__()
        self.device = device
        conf_path = 'inpaint/configs/prediction/default.yaml'
        if not os.path.exists(conf_path):
            print('Config file not found!')
        predict_config = OmegaConf.load(conf_path)
        predict_config.model.path = 'inpaint/weights/big-lama'
        train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
        with open(train_config_path, 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
        self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
        self.model.freeze()
        self.model.to(device)

        self.__first_forward__()

    def __first_forward__(self, input_size=(2048, 4096, 3)):
        # Run forward() once on the largest supported input to pin down peak GPU memory.
        print('initialize Inpaint Model...')
        _ = self.forward(np.random.rand(*input_size) * 255, None)
        print('initialize Complete!')

    def __resize_tensor__(self, image, max_size=1024, scale_factor=8):
        h, w = image.size()[2:]
        if max(h, w) > max_size:
            if h < w:
                h, w = int(max_size * h / w), max_size
            else:
                h, w = max_size, int(max_size * w / h)
        h = h // scale_factor * scale_factor
        w = w // scale_factor * scale_factor
        image = F.interpolate(image, (h, w), mode='bicubic')
        return image

    def input_preprocess_tensor(self, img):
        img_t = torch.from_numpy(img.astype(np.float32))  # .to(self.device)
        img_t = img_t.permute(2, 0, 1).unsqueeze(0)
        img_t = img_t / 255.
        img_t_for_net = self.__resize_tensor__(img_t).to(self.device)  # capped to bound peak GPU memory
        img_t_for_out = self.__resize_tensor__(img_t, max_size=2048).to(self.device)  # capped to bound peak GPU memory
        return img_t_for_net, img_t_for_out

    def forward(self, img, json_data):
        _, w, _ = img.shape
        mask = img[:, w // 2:, 0]   # right half carries the mask
        kernel = np.ones((4, 4), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=5)[:, :, np.newaxis]
        img = img[:, :w // 2, ::-1]  # left half is the image, BGR -> RGB
        # print(img.shape, mask.shape)
        img_t_for_net, img_t_for_out = self.input_preprocess_tensor(img)
        mask_t_for_net, mask_t_for_out = self.input_preprocess_tensor(mask)
        h, w = img_t_for_out.shape[2:]
        # print(img_t.shape, mask_t.shape)
        mask_t_for_out = (mask_t_for_out > 0).int()
        mask_t_for_net = (mask_t_for_net > 0).int()
        with torch.no_grad():
            res_t = self.model(dict(image=img_t_for_net, mask=mask_t_for_net))['inpainted']
        res_t = F.interpolate(res_t, (h, w), mode='bicubic')
        res_t = img_t_for_out * (1 - mask_t_for_out) + res_t * mask_t_for_out
        res_t = torch.clip(res_t * 255, min=0, max=255)
        res = res_t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        res = res[:, :, (2, 1, 0)].astype(np.uint8)
        return res
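forward()'s input contract is easy to miss: it takes a single BGR array whose left half is the photo and whose right half encodes the mask (nonzero = region to erase), which is the layout the SOD stage emits. A minimal sketch with hypothetical file names:

```python
import cv2
import numpy as np
import torch

photo = cv2.imread('photo.png')                      # BGR, hypothetical input
mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)  # nonzero where the object sits
side_by_side = np.concatenate([photo, np.repeat(mask[:, :, None], 3, axis=2)], axis=1)
result = IVModel(device=torch.device('cpu')).forward(side_by_side, None)  # BGR output
cv2.imwrite('inpainted.png', result)
```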
inpaint/predict.py
ADDED
@@ -0,0 +1,96 @@
#!/usr/bin/env python3

# Example command:
# ./bin/predict.py \
#       model.path=<path to checkpoint, prepared by make_checkpoint.py> \
#       indir=<path to input data> \
#       outdir=<where to store predicts>

import logging
import os
import sys
import traceback

from saicinpainting.evaluation.utils import move_to_device

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import cv2
# import hydra
import numpy as np
import torch
import tqdm
import yaml
from omegaconf import OmegaConf
from torch.utils.data._utils.collate import default_collate

from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.utils import register_debug_signal_handlers

LOGGER = logging.getLogger(__name__)


# @hydra.main(config_path='../configs/prediction', config_name='default.yaml')
def main(predict_config: OmegaConf):
    try:
        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        device = torch.device(predict_config.device)
        print(predict_config)
        train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
        with open(train_config_path, 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        out_ext = predict_config.get('out_ext', '.png')

        checkpoint_path = os.path.join(predict_config.model.path,
                                       'models',
                                       predict_config.model.checkpoint)
        model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
        model.freeze()
        model.to(device)

        if not predict_config.indir.endswith('/'):
            predict_config.indir += '/'

        dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
        with torch.no_grad():
            for img_i in tqdm.trange(len(dataset)):
                mask_fname = dataset.mask_filenames[img_i]
                cur_out_fname = os.path.join(
                    predict_config.outdir,
                    os.path.splitext(os.path.basename(mask_fname))[0] + out_ext
                )
                os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

                batch = move_to_device(default_collate([dataset[img_i]]), device)
                batch['mask'] = (batch['mask'] > 0) * 1
                # print(torch.max(batch['mask']), torch.min(batch['mask']), torch.max(batch['image']), torch.min(batch['image']))
                print(batch['mask'].dtype)
                batch = model(batch)
                cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

                cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
                cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
                cv2.imwrite(cur_out_fname, cur_res)
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)


if __name__ == '__main__':
    base_conf = OmegaConf.load('configs/prediction/default.yaml')
    base_conf.model.path = '../../weights/big-lama'
    base_conf.indir = '/home/zwp/Temp2018/ZWP/data/example/照片补全/测试图片/带mask的图片2'
    base_conf.outdir = 'zb_result/带mask的图片2'
    main(predict_config=base_conf)
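Since the @hydra.main decorator is commented out, the config is assembled by hand in __main__; the same pattern works for ad-hoc runs (paths below are hypothetical):

```python
from omegaconf import OmegaConf

conf = OmegaConf.load('configs/prediction/default.yaml')
conf.model.path = 'weights/big-lama'   # hypothetical checkpoint directory
conf.indir = 'inputs/'                 # images plus their mask companions
conf.outdir = 'outputs/'
main(predict_config=conf)
```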
inpaint/saicinpainting/training/modules/__init__.py
ADDED
@@ -0,0 +1,7 @@
from saicinpainting.training.modules.ffc import FFCResNetGenerator


def make_generator(config, kind, **kwargs):
    # `kind` is ignored: this stripped-down build always returns the FFC generator.
    return FFCResNetGenerator(**kwargs)
inpaint/saicinpainting/training/modules/base.py
ADDED
@@ -0,0 +1,80 @@
import abc
from typing import Tuple, List

import torch
import torch.nn as nn

from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv


class BaseDiscriminator(nn.Module):
    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Predict scores and get intermediate activations. Useful for feature matching loss
        :return tuple (scores, list of intermediate activations)
        """
        raise NotImplementedError()


def get_conv_block_ctor(kind='default'):
    if not isinstance(kind, str):
        return kind
    if kind == 'default':
        return nn.Conv2d
    if kind == 'depthwise':
        return DepthWiseSeperableConv
    if kind == 'multidilated':
        return MultidilatedConv
    raise ValueError(f'Unknown convolutional block kind {kind}')


def get_norm_layer(kind='bn'):
    if not isinstance(kind, str):
        return kind
    if kind == 'bn':
        return nn.BatchNorm2d
    if kind == 'in':
        return nn.InstanceNorm2d
    raise ValueError(f'Unknown norm block kind {kind}')


def get_activation(kind='tanh'):
    if kind == 'tanh':
        return nn.Tanh()
    if kind == 'sigmoid':
        return nn.Sigmoid()
    if kind is False:
        return nn.Identity()
    raise ValueError(f'Unknown activation kind {kind}')


class SimpleMultiStepGenerator(nn.Module):
    def __init__(self, steps: List[nn.Module]):
        super().__init__()
        self.steps = nn.ModuleList(steps)

    def forward(self, x):
        cur_in = x
        outs = []
        for step in self.steps:
            cur_out = step(cur_in)
            outs.append(cur_out)
            cur_in = torch.cat((cur_in, cur_out), dim=1)
        return torch.cat(outs[::-1], dim=1)


def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
    if kind == 'convtranspose':
        return [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                   min(max_features, int(ngf * mult / 2)),
                                   kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(min(max_features, int(ngf * mult / 2))), activation]
    elif kind == 'bilinear':
        return [nn.Upsample(scale_factor=2, mode='bilinear'),
                DepthWiseSeperableConv(min(max_features, ngf * mult),
                                       min(max_features, int(ngf * mult / 2)),
                                       kernel_size=3, stride=1, padding=1),
                norm_layer(min(max_features, int(ngf * mult / 2))), activation]
    else:
        raise Exception(f"Invalid deconv kind: {kind}")
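These string-keyed factories are what let YAML configs choose layer types by name; an already-constructed class is passed through unchanged. A small sketch of how the pieces compose (not from the repo):

```python
import torch
import torch.nn as nn

conv_ctor = get_conv_block_ctor('depthwise')  # -> DepthWiseSeperableConv
block = nn.Sequential(conv_ctor(16, 32, kernel_size=3, padding=1),
                      get_norm_layer('bn')(32),
                      get_activation('tanh'))
print(block(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 32, 8, 8])
```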
inpaint/saicinpainting/training/modules/depthwise_sep_conv.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn


class DepthWiseSeperableConv(nn.Module):
    def __init__(self, in_dim, out_dim, *args, **kwargs):
        super().__init__()
        if 'groups' in kwargs:
            # ignoring groups for Depthwise Sep Conv
            del kwargs['groups']

        self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
        self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
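A depthwise-separable convolution factors a dense k×k conv into a per-channel k×k conv plus a 1×1 channel mixer, which is where the parameter saving comes from; a quick check:

```python
import torch.nn as nn

dense = nn.Conv2d(64, 128, kernel_size=3, padding=1)
sep = DepthWiseSeperableConv(64, 128, kernel_size=3, padding=1)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense), count(sep))  # 73856 vs. 8960, roughly 8x fewer
```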
inpaint/saicinpainting/training/modules/fake_fakes.py
ADDED
@@ -0,0 +1,47 @@
import torch
from kornia import SamplePadding
from kornia.augmentation import RandomAffine, CenterCrop


class FakeFakesGenerator:
    def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
        self.grad_aug = RandomAffine(degrees=360,
                                     translate=0.2,
                                     padding_mode=SamplePadding.REFLECTION,
                                     keepdim=False,
                                     p=1)
        self.img_aug = RandomAffine(degrees=img_aug_degree,
                                    translate=img_aug_translate,
                                    padding_mode=SamplePadding.REFLECTION,
                                    keepdim=True,
                                    p=1)
        self.aug_proba = aug_proba

    def __call__(self, input_images, masks):
        blend_masks = self._fill_masks_with_gradient(masks)
        blend_target = self._make_blend_target(input_images)
        result = input_images * (1 - blend_masks) + blend_target * blend_masks
        return result, blend_masks

    def _make_blend_target(self, input_images):
        batch_size = input_images.shape[0]
        permuted = input_images[torch.randperm(batch_size)]
        augmented = self.img_aug(input_images)
        is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
        result = augmented * is_aug + permuted * (1 - is_aug)
        return result

    def _fill_masks_with_gradient(self, masks):
        batch_size, _, height, width = masks.shape
        grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
            .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
        grad = self.grad_aug(grad)
        grad = CenterCrop((height, width))(grad)
        grad *= masks

        grad_for_min = grad + (1 - masks) * 10
        grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
        grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
        grad.clamp_(min=0, max=1)

        return grad
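A quick shape check of the blending behavior (a sketch; assumes a kornia version where `SamplePadding` is importable from the top-level package, as the import above expects):

```python
import torch

gen = FakeFakesGenerator(aug_proba=0.5)
imgs = torch.rand(4, 3, 64, 64)
masks = (torch.rand(4, 1, 64, 64) > 0.5).float()
fakes, blend_masks = gen(imgs, masks)
print(fakes.shape, float(blend_masks.max()) <= 1.0)  # torch.Size([4, 3, 64, 64]) True
```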
inpaint/saicinpainting/training/modules/ffc.py
ADDED
@@ -0,0 +1,367 @@
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf

import torch
import torch.nn as nn
import torch.nn.functional as F

from saicinpainting.training.modules.base import get_activation
from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
from saicinpainting.training.modules.squeeze_excitation import SELayer


class FFCSE_block(nn.Module):

    def __init__(self, channels, ratio_g):
        super(FFCSE_block, self).__init__()
        in_cg = int(channels * ratio_g)
        in_cl = channels - in_cg
        r = 16

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(channels, channels // r,
                               kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
            channels // r, in_cl, kernel_size=1, bias=True)
        self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
            channels // r, in_cg, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x if type(x) is tuple else (x, 0)
        id_l, id_g = x

        x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
        x = self.avgpool(x)
        x = self.relu1(self.conv1(x))

        x_l = 0 if self.conv_a2l is None else id_l * \
            self.sigmoid(self.conv_a2l(x))
        x_g = 0 if self.conv_a2g is None else id_g * \
            self.sigmoid(self.conv_a2g(x))
        return x_l, x_g


class FourierUnit(nn.Module):

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups

        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output


class SpectralTransform(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):

        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_s = h // split_no
            xs = torch.cat(torch.split(
                x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_s, dim=-1),
                           dim=1).contiguous()
            xs = self.lfu(xs)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg
        # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
        # groups_l = 1 if groups == 1 else groups - groups_g

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect',
                 enable_lfu=True, **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        if spatial_transform_kwargs is not None:
            self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
            self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
        self.inline = inline

    def forward(self, x):
        if self.inline:
            x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out


class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class FFCResNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', activation_layer=nn.ReLU,
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
                 init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
                 spatial_transform_layers=None, spatial_transform_kwargs={},
                 add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        model = [nn.ReflectionPad2d(3),
                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
                            activation_layer=activation_layer, **init_conv_kwargs)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            if i == n_downsampling - 1:
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [FFC_BN_ACT(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
                                 **cur_conv_kwargs)]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                          activation_layer=activation_layer,
                                          norm_layer=norm_layer, **resnet_conv_kwargs)
            if spatial_transform_layers is not None and i in spatial_transform_layers:
                cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
            model += [cur_resblock]

        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]

        if out_ffc:
            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
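A minimal instantiation sketch; the ratio and enable_lfu values below mirror common LaMa configurations but are assumptions here, not values taken from this commit:

```python
import torch

gen = FFCResNetGenerator(input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=9,
                         init_conv_kwargs={'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False},
                         downsample_conv_kwargs={'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False},
                         resnet_conv_kwargs={'ratio_gin': 0.75, 'ratio_gout': 0.75, 'enable_lfu': False})
masked = torch.randn(1, 4, 256, 256)  # 3 image channels + 1 mask channel
print(gen(masked).shape)              # torch.Size([1, 3, 256, 256])
```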
inpaint/saicinpainting/training/modules/multidilated_conv.py
ADDED
@@ -0,0 +1,98 @@
import random

import torch
import torch.nn as nn

from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv


class MultidilatedConv(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
                 shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False,
                 **kwargs):
        super().__init__()
        convs = []
        self.equal_dim = equal_dim
        assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
        if comb_mode in ('cat_out', 'cat_both'):
            self.cat_out = True
            if equal_dim:
                assert out_dim % dilation_num == 0
                out_dims = [out_dim // dilation_num] * dilation_num
                self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)]
                                  for i in range(out_dims[0])], [])
            else:
                out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                out_dims.append(out_dim - sum(out_dims))
                index = []
                starts = [0] + out_dims[:-1]
                lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
                for i in range(out_dims[-1]):
                    for j in range(dilation_num):
                        index += list(range(starts[j], starts[j] + lengths[j]))
                        starts[j] += lengths[j]
                self.index = index
                assert (len(index) == out_dim)
            self.out_dims = out_dims
        else:
            self.cat_out = False
            self.out_dims = [out_dim] * dilation_num

        if comb_mode in ('cat_in', 'cat_both'):
            if equal_dim:
                assert in_dim % dilation_num == 0
                in_dims = [in_dim // dilation_num] * dilation_num
            else:
                in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
                in_dims.append(in_dim - sum(in_dims))
            self.in_dims = in_dims
            self.cat_in = True
        else:
            self.cat_in = False
            self.in_dims = [in_dim] * dilation_num

        conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
        dilation = min_dilation
        for i in range(dilation_num):
            if isinstance(padding, int):
                cur_padding = padding * dilation
            else:
                cur_padding = padding[i]
            convs.append(conv_type(
                self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
            ))
            if i > 0 and shared_weights:
                convs[-1].weight = convs[0].weight
                convs[-1].bias = convs[0].bias
            dilation *= 2
        self.convs = nn.ModuleList(convs)

        self.shuffle_in_channels = shuffle_in_channels
        if self.shuffle_in_channels:
            # shuffle list as shuffling of tensors is nondeterministic
            in_channels_permute = list(range(in_dim))
            random.shuffle(in_channels_permute)
            # save as buffer so it is saved and loaded with checkpoint
            self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))

    def forward(self, x):
        if self.shuffle_in_channels:
            x = x[:, self.in_channels_permute]

        outs = []
        if self.cat_in:
            if self.equal_dim:
                x = x.chunk(len(self.convs), dim=1)
            else:
                new_x = []
                start = 0
                for dim in self.in_dims:
                    new_x.append(x[:, start:start + dim])
                    start += dim
                x = new_x
        for i, conv in enumerate(self.convs):
            if self.cat_in:
                input = x[i]
            else:
                input = x
            outs.append(conv(input))
        if self.cat_out:
            out = torch.cat(outs, dim=1)[:, self.index]
        else:
            out = sum(outs)
        return out
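With the default comb_mode='sum', the module runs parallel convs at dilations 1, 2, 4, ... and sums them; output spatial size matches the input as long as padding scales with dilation, which the default padding=1 does for kernel_size=3. A shape check:

```python
import torch

conv = MultidilatedConv(in_dim=16, out_dim=32, kernel_size=3, dilation_num=3)
print(conv(torch.randn(1, 16, 32, 32)).shape)  # torch.Size([1, 32, 32, 32])
```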
inpaint/saicinpainting/training/modules/multiscale.py
ADDED
@@ -0,0 +1,244 @@
from typing import List, Tuple, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
from saicinpainting.training.modules.pix2pixhd import ResnetBlock


class ResNetHead(nn.Module):
    def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
        assert (n_blocks >= 0)
        super(ResNetHead, self).__init__()

        conv_layer = get_conv_block_ctor(conv_kind)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      activation]

        mult = 2 ** n_downsampling

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class ResNetTail(nn.Module):
    def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False,
                 out_extra_layers_n=0, add_in_proj=None):
        assert (n_blocks >= 0)
        super(ResNetTail, self).__init__()

        mult = 2 ** n_downsampling

        model = []

        if add_in_proj is not None:
            model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                         output_padding=1),
                      up_norm_layer(int(ngf * mult / 2)),
                      up_activation]
        self.model = nn.Sequential(*model)

        out_layers = []
        for _ in range(out_extra_layers_n):
            out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
                           up_norm_layer(ngf),
                           up_activation]
        out_layers += [nn.ReflectionPad2d(3),
                       nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))

        self.out_proj = nn.Sequential(*out_layers)

    def forward(self, input, return_last_act=False):
        features = self.model(input)
        out = self.out_proj(features)
        if return_last_act:
            return out, features
        else:
            return out


class MultiscaleResNet(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False,
                 out_extra_layers_n=0, out_cumulative=False, return_only_hr=False):
        super().__init__()

        self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
                                               n_blocks=n_blocks_head, norm_layer=norm_layer,
                                               padding_type=padding_type,
                                               conv_kind=conv_kind, activation=activation)
                                    for i in range(n_scales)])
        tail_in_feats = ngf * (2 ** n_downsampling) + ngf
        self.tails = nn.ModuleList([ResNetTail(output_nc,
                                               ngf=ngf, n_downsampling=n_downsampling,
                                               n_blocks=n_blocks_tail, norm_layer=norm_layer,
                                               padding_type=padding_type,
                                               conv_kind=conv_kind, activation=activation,
                                               up_norm_layer=up_norm_layer,
                                               up_activation=up_activation, add_out_act=add_out_act,
                                               out_extra_layers_n=out_extra_layers_n,
                                               add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
                                    for i in range(n_scales)])

        self.out_cumulative = out_cumulative
        self.return_only_hr = return_only_hr

    @property
    def num_scales(self):
        return len(self.heads)

    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
            -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        :param ms_inputs: List of inputs of different resolutions from HR to LR
        :param smallest_scales_num: int or None, number of smallest scales to take at input
        :return: Depending on return_only_hr:
            True: Only the most HR output
            False: List of outputs of different resolutions from HR to LR
        """
        if smallest_scales_num is None:
            assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
            smallest_scales_num = len(self.heads)
        else:
            assert smallest_scales_num == len(ms_inputs) <= len(self.heads), \
                (len(self.heads), len(ms_inputs), smallest_scales_num)

        cur_heads = self.heads[-smallest_scales_num:]
        ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]

        all_outputs = []
        prev_tail_features = None
        for i in range(len(ms_features)):
            scale_i = -i - 1

            cur_tail_input = ms_features[-i - 1]
            if prev_tail_features is not None:
                if prev_tail_features.shape != cur_tail_input.shape:
                    prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
                                                       mode='bilinear', align_corners=False)
                cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)

            cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)

            prev_tail_features = cur_tail_feats
            all_outputs.append(cur_out)

        if self.out_cumulative:
            all_outputs_cum = [all_outputs[0]]
            for i in range(1, len(ms_features)):
                cur_out = all_outputs[i]
                cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
                                                      mode='bilinear', align_corners=False)
                all_outputs_cum.append(cur_out_cum)
            all_outputs = all_outputs_cum

        if self.return_only_hr:
            return all_outputs[-1]
        else:
            return all_outputs[::-1]


class MultiscaleDiscriminatorSimple(nn.Module):
    def __init__(self, ms_impl):
        super().__init__()
        self.ms_impl = nn.ModuleList(ms_impl)

    @property
    def num_scales(self):
        return len(self.ms_impl)

    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
            -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        :param ms_inputs: List of inputs of different resolutions from HR to LR
        :param smallest_scales_num: int or None, number of smallest scales to take at input
        :return: List of pairs (prediction, features) for different resolutions from HR to LR
        """
        if smallest_scales_num is None:
            assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
            smallest_scales_num = len(self.ms_impl)
        else:
            assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
                (len(self.ms_impl), len(ms_inputs), smallest_scales_num)

        return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]


class SingleToMultiScaleInputMixin:
    def forward(self, x: torch.Tensor) -> List:
        orig_height, orig_width = x.shape[2:]
        factors = [2 ** i for i in range(self.num_scales)]
        ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
                     for f in factors]
        return super().forward(ms_inputs)


class GeneratorMultiToSingleOutputMixin:
    def forward(self, x):
        return super().forward(x)[0]


class DiscriminatorMultiToSingleOutputMixin:
    def forward(self, x):
        out_feat_tuples = super().forward(x)
        return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]


class DiscriminatorMultiToSingleOutputStackedMixin:
    def __init__(self, *args, return_feats_only_levels=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_feats_only_levels = return_feats_only_levels

    def forward(self, x):
        out_feat_tuples = super().forward(x)
        outs = [out for out, _ in out_feat_tuples]
        scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
                                                 mode='bilinear', align_corners=False)
                                   for cur_out in outs[1:]]
        out = torch.cat(scaled_outs, dim=1)
        if self.return_feats_only_levels is not None:
            feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
        else:
            feat_lists = [flist for _, flist in out_feat_tuples]
        feats = [f for flist in feat_lists for f in flist]
        return out, feats


class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
    pass


class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
    pass
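MultiscaleResNet consumes an image pyramid ordered HR to LR and refines from the coarsest scale upward, feeding each tail's pre-projection features into the next finer tail. A minimal sketch of the expected input format:

```python
import torch
import torch.nn.functional as F

net = MultiscaleResNet(input_nc=3, output_nc=3, n_scales=3)
x = torch.randn(2, 3, 64, 64)
pyramid = [x,
           F.interpolate(x, scale_factor=0.5),   # half resolution
           F.interpolate(x, scale_factor=0.25)]  # quarter resolution
outs = net(pyramid)
print([tuple(o.shape[-2:]) for o in outs])  # [(64, 64), (32, 32), (16, 16)]
```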
inpaint/saicinpainting/training/modules/pix2pixhd.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
import collections
from functools import partial
import functools
import logging
from collections import defaultdict

import numpy as np
import torch.nn as nn

from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
from saicinpainting.training.modules.ffc import FFCResnetBlock
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv


class DotDict(defaultdict):
    # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    """dot.notation access to dictionary attributes"""
    __getattr__ = defaultdict.get
    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channnels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation)]
        elif padding_type == 'zero':
            p = dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation)]
        elif padding_type == 'zero':
            p = second_dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out


class ResnetBlock5x5(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock5x5, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channnels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation * 2)]
        elif padding_type == 'zero':
            p = dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
        elif padding_type == 'zero':
            p = second_dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out


class MultidilatedResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
        conv_block = []
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class MultiDilatedGlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024, multidilation_kwargs={},
                 ffc_positions=None, ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            if ffc_positions is not None and i in ffc_positions:
                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                         inline=True, **ffc_kwargs)]
            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                              conv_layer=resnet_conv_layer, activation=activation,
                                              norm_layer=norm_layer)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class ConfigGlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024,
                 manual_block_spec=[],
                 resnet_block_kind='multidilatedresnetblock',
                 resnet_conv_kind='multidilated',
                 resnet_dilation=1,
                 multidilation_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        if len(manual_block_spec) == 0:
            manual_block_spec = [
                DotDict(lambda : None, {
                    'n_blocks': n_blocks,
                    'use_default': True})
            ]

        ### resnet blocks
        for block_spec in manual_block_spec:
            def make_and_add_blocks(model, block_spec):
                block_spec = DotDict(lambda : None, block_spec)
                if not block_spec.use_default:
                    resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
                    resnet_conv_kind = block_spec.resnet_conv_kind
                    resnet_block_kind = block_spec.resnet_block_kind
                    if block_spec.resnet_dilation is not None:
                        resnet_dilation = block_spec.resnet_dilation
                for i in range(block_spec.n_blocks):
                    if resnet_block_kind == "multidilatedresnetblock":
                        model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                                          conv_layer=resnet_conv_layer, activation=activation,
                                                          norm_layer=norm_layer)]
                    if resnet_block_kind == "resnetblock":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                              conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblock5x5":
                        model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                                 conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblockdwdil":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                              conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
            make_and_add_blocks(model, block_spec)

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
    blocks = []
    for i in range(dilated_blocks_n):
        if dilation_block_kind == 'simple':
            blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
        elif dilation_block_kind == 'multi':
            blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
        else:
            raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')
    return blocks


class GlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None,
                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
                 dilated_blocks_n_middle=0,
                 add_out_act=True,
                 max_features=1024, is_resblock_depthwise=False,
                 ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
                 dilation_block_kind='simple', multidilation_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        if ffc_positions is not None:
            ffc_positions = collections.Counter(ffc_positions)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
                                    activation=activation, norm_layer=norm_layer)
        if dilation_block_kind == 'simple':
            dilated_block_kwargs['conv_kind'] = conv_kind
        elif dilation_block_kind == 'multi':
            dilated_block_kwargs['conv_layer'] = functools.partial(
                get_conv_block_ctor('multidilated'), **multidilation_kwargs)

        # dilated blocks at the start of the bottleneck sausage
        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)

        # resnet blocks
        for i in range(n_blocks):
            # dilated blocks at the middle of the bottleneck sausage
            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)

            if ffc_positions is not None and i in ffc_positions:
                for _ in range(ffc_positions[i]):  # same position can occur more than once
                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                             inline=True, **ffc_kwargs)]

            if is_resblock_depthwise:
                resblock_groups = feats_num_bottleneck
            else:
                resblock_groups = 1

            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
                                  dilation=dilation, second_dilation=second_dilation)]

        # dilated blocks at the end of the bottleneck sausage
        if dilated_blocks_n is not None and dilated_blocks_n > 0:
            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class GlobalGeneratorGated(GlobalGenerator):
    def __init__(self, *args, **kwargs):
        real_kwargs = dict(
            conv_kind='gated_bn_relu',
            activation=nn.Identity(),
            norm_layer=nn.Identity
        )
        real_kwargs.update(kwargs)
        super().__init__(*args, **real_kwargs)


class GlobalGeneratorFromSuperChannels(nn.Module):
    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
        super().__init__()
        self.n_downsampling = n_downsampling
        norm_layer = get_norm_layer(norm_layer)
        if type(norm_layer) == functools.partial:
            use_bias = (norm_layer.func == nn.InstanceNorm2d)
        else:
            use_bias = (norm_layer == nn.InstanceNorm2d)

        channels = self.convert_super_channels(super_channels)
        self.channels = channels

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(channels[0]),
                 nn.ReLU(True)]

        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(channels[0 + i], channels[1 + i], kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(channels[1 + i]),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling

        n_blocks1 = n_blocks // 3
        n_blocks2 = n_blocks1
        n_blocks3 = n_blocks - n_blocks1 - n_blocks2

        for i in range(n_blocks1):
            c = n_downsampling
            dim = channels[c]
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]

        for i in range(n_blocks2):
            c = n_downsampling + 1
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_blocks3):
            c = n_downsampling + 2
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(channels[n_downsampling + 3 + i],
                                         channels[n_downsampling + 3 + i + 1],
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(channels[n_downsampling + 3 + i + 1]),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def convert_super_channels(self, super_channels):
        n_downsampling = self.n_downsampling
        result = []
        cnt = 0

        if n_downsampling == 2:
            N1 = 10
        elif n_downsampling == 3:
            N1 = 13
        else:
            raise NotImplementedError

        for i in range(0, N1):
            if i in [1, 4, 7, 10]:
                channel = super_channels[cnt] * (2 ** cnt)
                config = {'channel': channel}
                result.append(channel)
                logging.info(f"Downsample channels {result[-1]}")
                cnt += 1

        for i in range(3):
            for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
                if len(super_channels) == 6:
                    channel = super_channels[3] * 4
                else:
                    channel = super_channels[i + 3] * 4
                config = {'channel': channel}
                if counter == 0:
                    result.append(channel)
                    logging.info(f"Bottleneck channels {result[-1]}")
        cnt = 2

        for i in range(N1 + 9, N1 + 21):
            if i in [22, 25, 28]:
                cnt -= 1
                if len(super_channels) == 6:
                    channel = super_channels[5 - cnt] * (2 ** cnt)
                else:
                    channel = super_channels[7 - cnt] * (2 ** cnt)
                result.append(int(channel))
                logging.info(f"Upsample channels {result[-1]}")
        return result

    def forward(self, input):
        return self.model(input)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseDiscriminator):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = []
            cur_model += [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = []
        cur_model += [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]


class MultidilatedNLayerDiscriminator(BaseDiscriminator):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = []
            cur_model += [
                MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = []
        cur_model += [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]


class NLayerDiscriminatorAsGen(NLayerDiscriminator):
    def forward(self, x):
        return super().forward(x)[0]
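A minimal usage sketch, not part of the commit itself, assuming the saicinpainting package above is importable: the PatchGAN discriminator returns the final patch logits together with the intermediate activations consumed by feature-matching losses.

import torch
from saicinpainting.training.modules.pix2pixhd import NLayerDiscriminator

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
score, feats = disc(torch.randn(1, 3, 256, 256))
# score: (1, 1, h', w') per-patch realness logits
# feats: n_layers + 1 = 4 intermediate feature maps, early (high-res) to late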
inpaint/saicinpainting/training/modules/spatial_transform.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry.transform import rotate


class LearnableSpatialTransformWrapper(nn.Module):
    def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
        super().__init__()
        self.impl = impl
        self.angle = torch.rand(1) * angle_init_range
        if train_angle:
            self.angle = nn.Parameter(self.angle, requires_grad=True)
        self.pad_coef = pad_coef

    def forward(self, x):
        if torch.is_tensor(x):
            return self.inverse_transform(self.impl(self.transform(x)), x)
        elif isinstance(x, tuple):
            x_trans = tuple(self.transform(elem) for elem in x)
            y_trans = self.impl(x_trans)
            return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
        else:
            raise ValueError(f'Unexpected input type {type(x)}')

    def transform(self, x):
        height, width = x.shape[2:]
        pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
        x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
        x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
        return x_padded_rotated

    def inverse_transform(self, y_padded_rotated, orig_x):
        height, width = orig_x.shape[2:]
        pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)

        y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
        y_height, y_width = y_padded.shape[2:]
        y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
        return y


if __name__ == '__main__':
    layer = LearnableSpatialTransformWrapper(nn.Identity())
    x = torch.arange(2 * 3 * 15 * 15).view(2, 3, 15, 15).float()
    y = layer(x)
    assert x.shape == y.shape
    assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
    print('all ok')
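Beyond the identity self-test above, a minimal sketch (assuming kornia is installed) of wrapping a shape-preserving module, which is how the wrapper is intended to be used: the input is reflection-padded and rotated by the learned angle, `impl` runs on that view, and the inverse rotation plus crop restores the original geometry.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)            # shape-preserving impl
layer = LearnableSpatialTransformWrapper(conv, pad_coef=0.5, train_angle=False)
out = layer(torch.randn(1, 3, 32, 32))
assert out.shape == (1, 3, 32, 32)                          # pad + rotate are undone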
inpaint/saicinpainting/training/modules/squeeze_excitation.py
ADDED
@@ -0,0 +1,20 @@
import torch.nn as nn


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        res = x * y.expand_as(x)
        return res
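A minimal usage sketch: SELayer squeezes each channel to a scalar by global average pooling, maps it through the two-layer bottleneck, and rescales the input channels by the resulting sigmoid gates.

import torch

se = SELayer(channel=64, reduction=16)
x = torch.randn(2, 64, 28, 28)
y = se(x)                  # same shape; each channel multiplied by a gate in (0, 1)
assert y.shape == x.shape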
inpaint/saicinpainting/training/trainers/__init__.py
ADDED
@@ -0,0 +1,26 @@
import torch
from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule


def get_training_model_class(kind):
    if kind == 'default':
        return DefaultInpaintingTrainingModule

    raise ValueError(f'Unknown trainer module {kind}')


def make_training_model(config):
    kind = config.training_model.kind
    kwargs = dict(config.training_model)
    kwargs.pop('kind')
    kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
    cls = get_training_model_class(kind)
    return cls(config, **kwargs)


def load_checkpoint(train_config, path, map_location='cuda', strict=True):
    model: torch.nn.Module = make_training_model(train_config)
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state['state_dict'], strict=strict)
    model.on_load_checkpoint(state)
    return model
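A minimal sketch of restoring a model with the helper above. The paths and the OmegaConf usage are assumptions (a LaMa-style training config with `training_model` and `trainer` sections), not files from this commit.

import torch
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint

train_config = OmegaConf.load('big-lama/config.yaml')               # hypothetical path
model = load_checkpoint(train_config, 'big-lama/models/best.ckpt',  # hypothetical path
                        map_location='cpu', strict=False)
model.eval()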
inpaint/saicinpainting/training/trainers/base.py
ADDED
@@ -0,0 +1,19 @@
import pytorch_lightning as ptl
from inpaint.saicinpainting.training.modules import make_generator


class BaseInpaintingTrainingModule(ptl.LightningModule):
    def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
                 average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
                 average_generator_period=10, store_discr_outputs_for_vis=False,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.config = config
        self.generator = make_generator(config, **self.config.generator)
        self.use_ddp = use_ddp
        self.visualize_each_iters = visualize_each_iters
inpaint/saicinpainting/training/trainers/default.py
ADDED
@@ -0,0 +1,53 @@
import torch

from saicinpainting.training.trainers.base import BaseInpaintingTrainingModule


class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
    def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
                 add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
                 distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
                 fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.concat_mask = concat_mask
        self.image_to_discriminator = image_to_discriminator
        self.add_noise_kwargs = add_noise_kwargs
        self.noise_fill_hole = noise_fill_hole
        self.const_area_crop_kwargs = const_area_crop_kwargs
        # print(distance_weighter_kwargs)
        self.refine_mask_for_losses = None
        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr

        self.fake_fakes_proba = fake_fakes_proba

    def forward(self, batch):
        img = batch['image']
        mask = batch['mask']

        masked_img = img * (1 - mask)
        if self.concat_mask:
            masked_img = torch.cat([masked_img, mask], dim=1)

        batch['predicted_image'] = self.generator(masked_img)
        batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']

        if self.fake_fakes_proba > 1e-3:
            if self.training and torch.rand(1).item() < self.fake_fakes_proba:
                batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
                batch['use_fake_fakes'] = True
            else:
                batch['fake_fakes'] = torch.zeros_like(img)
                batch['fake_fakes_masks'] = torch.zeros_like(mask)
                batch['use_fake_fakes'] = False

        batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
            if self.refine_mask_for_losses is not None and self.training \
            else mask

        return batch
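A minimal sketch of the batch contract that forward expects, inferred from the compositing line above: 'mask' is 1 where pixels must be synthesized, and 'image' is assumed to be a float tensor in [0, 1].

import torch

batch = {
    'image': torch.rand(1, 3, 256, 256),                 # B x 3 x H x W
    'mask': (torch.rand(1, 1, 256, 256) > 0.9).float(),  # B x 1 x H x W, 1 = hole
}
# batch = module(batch)  # adds 'predicted_image' and the composited 'inpainted'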
sod/PGNet.py
ADDED
@@ -0,0 +1,270 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Res import resnet18
from .Swin import Swintransformer

Act = nn.ReLU


def weight_init(module):
    for n, m in module.named_children():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d)):
            nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Sequential):
            weight_init(m)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.ReLU, Act, nn.AdaptiveAvgPool2d, nn.Softmax)):
            pass
        else:
            m.initialize()


class Grafting(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.act = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1)
        self.lnx = nn.LayerNorm(64)
        self.lny = nn.LayerNorm(64)
        self.bn = nn.BatchNorm2d(8)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, y):
        # cross-attention: queries/values come from x (ResNet branch), keys from y (Swin branch)
        batch_size = x.shape[0]
        chanel = x.shape[1]
        sc = x
        x = x.view(batch_size, chanel, -1).permute(0, 2, 1)
        sc1 = x
        x = self.lnx(x)
        y = y.view(batch_size, chanel, -1).permute(0, 2, 1)
        y = self.lny(y)

        B, N, C = x.shape
        y_k = self.k(y).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        x_qv = self.qv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        x_q, x_v = x_qv[0], x_qv[1]
        y_k = y_k[0]
        attn = (x_q @ y_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ x_v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = (x + sc1)

        x = x.permute(0, 2, 1)
        x = x.view(batch_size, chanel, *sc.size()[2:])
        x = self.conv2(x) + x
        return x, self.act(self.bn(self.conv(attn + attn.transpose(-1, -2))))

    def initialize(self):
        weight_init(self)


class DB1(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(DB1, self).__init__()
        self.squeeze1 = nn.Sequential(
            nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.squeeze2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, dilation=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        z = self.squeeze2(self.squeeze1(x))
        return z, z

    def initialize(self):
        weight_init(self)


class DB2(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(DB2, self).__init__()
        self.short_cut = nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.conv = nn.Sequential(
            nn.Conv2d(inplanes + outplanes, outplanes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True),
            nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, z):
        z = F.interpolate(z, size=x.size()[2:], mode='bilinear', align_corners=True)
        p = self.conv(torch.cat((x, z), 1))
        sc = self.short_cut(z)
        p = p + sc
        p2 = self.conv2(p)
        p = p + p2
        return p, p

    def initialize(self):
        weight_init(self)


class DB3(nn.Module):
    def __init__(self) -> None:
        super(DB3, self).__init__()

        self.db2 = DB2(64, 64)

        self.conv3x3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.sqz_r4 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.sqz_s1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

    def forward(self, s, r, up):
        up = F.interpolate(up, size=s.size()[2:], mode='bilinear', align_corners=True)
        s = self.sqz_s1(s)
        r = self.sqz_r4(r)
        sr = self.conv3x3(s + r)
        out, _ = self.db2(sr, up)
        return out, out

    def initialize(self):
        weight_init(self)


class decoder(nn.Module):
    def __init__(self) -> None:
        super(decoder, self).__init__()
        self.sqz_s2 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.sqz_r5 = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.GF = Grafting(64, num_heads=8)
        self.d1 = DB1(512, 64)
        self.d2 = DB2(512, 64)
        self.d3 = DB2(64, 64)
        self.d4 = DB3()
        self.d5 = DB2(128, 64)
        self.d6 = DB2(64, 64)

    def forward(self, s1, s2, s3, s4, r2, r3, r4, r5):
        r5 = F.interpolate(r5, size=s2.size()[2:], mode='bilinear', align_corners=True)
        s1 = F.interpolate(s1, size=r4.size()[2:], mode='bilinear', align_corners=True)

        s4_, _ = self.d1(s4)
        s3_, _ = self.d2(s3, s4_)

        s2_ = self.sqz_s2(s2)
        r5_ = self.sqz_r5(r5)
        graft_feature_r5, cam = self.GF(r5_, s2_)

        graft_feature_r5_, _ = self.d3(graft_feature_r5, s3_)

        graft_feature_r4, _ = self.d4(s1, r4, graft_feature_r5_)

        r3_, _ = self.d5(r3, graft_feature_r4)

        r2_, _ = self.d6(r2, r3_)

        return r2_, cam, r5_, s2_

    def initialize(self):
        weight_init(self)


class PGNet(nn.Module):
    def __init__(self, cfg=None):
        super(PGNet, self).__init__()
        self.cfg = cfg
        self.decoder = decoder()
        self.linear1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linear2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linear3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.conv = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1)

        if self.cfg is None or self.cfg.snapshot is None:
            weight_init(self)

        # backbone weights are expected under sod/weights/ (fetched separately)
        self.resnet = resnet18()
        self.swin = Swintransformer(224)
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.swin.load_state_dict(torch.load('sod/weights/swin224.pth', map_location=device)['model'], strict=False)
        self.resnet.load_state_dict(torch.load('sod/weights/resnet18.pth', map_location=device), strict=False)

        if self.cfg is not None and self.cfg.snapshot:
            print('load checkpoint')
            pretrain = torch.load(self.cfg.snapshot, map_location=device)
            new_state_dict = {}
            for k, v in pretrain.items():
                # drop the 7-character 'module.' prefix left by nn.DataParallel checkpoints
                new_state_dict[k[7:]] = v
            self.load_state_dict(new_state_dict, strict=False)

    def forward(self, x, shape=None, mask=None):
        shape = x.size()[2:] if shape is None else shape
        y = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True)

        r2, r3, r4, r5 = self.resnet(x)
        s1, s2, s3, s4 = self.swin(y)
        r2_, attmap, r5_, s2_ = self.decoder(s1, s2, s3, s4, r2, r3, r4, r5)

        pred1 = F.interpolate(self.linear1(r2_), size=shape, mode='bilinear')
        wr = F.interpolate(self.linear2(r5_), size=(28, 28), mode='bilinear')
        ws = F.interpolate(self.linear3(s2_), size=(28, 28), mode='bilinear')

        return pred1, wr, ws, self.conv(attmap)
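A minimal shape sketch of the Grafting block in isolation: queries and values come from the ResNet-branch features x, keys from the Swin-branch features y, and the per-head attention summary (`cam`) is returned alongside the fused map (PGNet later squeezes it to one channel with a 3x3 conv).

import torch

graft = Grafting(dim=64, num_heads=8)
x = torch.randn(2, 64, 28, 28)            # ResNet branch features
y = torch.randn(2, 64, 28, 28)            # Swin branch features
fused, cam = graft(x, y)
# fused: (2, 64, 28, 28); cam: (2, 8, 784, 784) with N = 28 * 28 tokens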
sod/Res.py
ADDED
@@ -0,0 +1,363 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

    def initialize(self):
        weight_init(self)


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

    def initialize(self):
        weight_init(self)


def weight_init(module):
    for n, m in module.named_children():
        # print('initialize: '+n)
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Sequential):
            weight_init(m)
        elif isinstance(m, (nn.ReLU, nn.AdaptiveAvgPool2d, nn.Softmax, nn.MaxPool2d)):
            pass
        else:
            m.initialize()


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = self.maxpool(out1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5

    def initialize(self):
        weight_init(self)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)

    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
317 |
+
r"""ResNeXt-101 32x8d model from
|
318 |
+
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
319 |
+
|
320 |
+
Args:
|
321 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
322 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
323 |
+
"""
|
324 |
+
kwargs['groups'] = 32
|
325 |
+
kwargs['width_per_group'] = 8
|
326 |
+
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
327 |
+
pretrained, progress, **kwargs)
|
328 |
+
|
329 |
+
|
330 |
+
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
331 |
+
r"""Wide ResNet-50-2 model from
|
332 |
+
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
333 |
+
|
334 |
+
The model is the same as ResNet except for the bottleneck number of channels
|
335 |
+
which is twice larger in every block. The number of channels in outer 1x1
|
336 |
+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
337 |
+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
341 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
342 |
+
"""
|
343 |
+
kwargs['width_per_group'] = 64 * 2
|
344 |
+
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
345 |
+
pretrained, progress, **kwargs)
|
346 |
+
|
347 |
+
|
348 |
+
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
349 |
+
r"""Wide ResNet-101-2 model from
|
350 |
+
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
351 |
+
|
352 |
+
The model is the same as ResNet except for the bottleneck number of channels
|
353 |
+
which is twice larger in every block. The number of channels in outer 1x1
|
354 |
+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
355 |
+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
359 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
360 |
+
"""
|
361 |
+
kwargs['width_per_group'] = 64 * 2
|
362 |
+
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
363 |
+
pretrained, progress, **kwargs)
|
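These factory functions keep the torchvision signatures, but `_resnet` deliberately skips the `pretrained`/`progress` download logic: the backbone weights arrive with the PGNet checkpoint instead, and `forward` returns the four stage features rather than logits. A minimal usage sketch (assuming the `Bottleneck`, `BasicBlock`, `conv1x1`, and `weight_init` helpers defined earlier in this file):

# hedged sketch: feature-pyramid backbone, no classification head
import torch
net = resnet50()                          # pretrained/progress are accepted but ignored here
feats = net(torch.randn(1, 3, 224, 224))  # (out2, out3, out4, out5)
for f in feats:                           # strides 4/8/16/32; 256/512/1024/2048 channels for resnet50
    print(f.shape)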
sod/Swin.py
ADDED
@@ -0,0 +1,578 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            atten_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            atten_mask = atten_mask.masked_fill(atten_mask != 0, float(-100.0)).masked_fill(atten_mask == 0, float(0.0))
        else:
            atten_mask = None

        self.register_buffer("atten_mask", atten_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.atten_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class Swintransformer(nn.Module):
    r""" Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers - 1):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 2) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        b, c, h, w = x.shape
        x = self.patch_embed(x)

        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        s = []
        s.append(x.view(b, int((x.shape[1]) ** 0.5), int((x.shape[1]) ** 0.5), -1).permute(0, 3, 1, 2).contiguous())
        for layer in self.layers:
            x = layer(x)
            s.append(x.view(b, int((x.shape[1]) ** 0.5), int((x.shape[1]) ** 0.5), -1).permute(0, 3, 1, 2).contiguous())

        # x = self.norm(x)  # B L C
        # x = self.avgpool(x.transpose(1, 2))  # B C 1
        # x = torch.flatten(x, 1)
        return s

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        return flops
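Like the ResNet file above, `Swintransformer` drops the classification head: `forward_features` collects the token grid after patch embedding and after each built stage, reshaped back to B x C x H x W maps for the PGNet decoder (note the constructor builds only the first `num_layers - 1` stages). A short usage sketch (assuming the Swin-B-style defaults above; the input must match `img_size` exactly because of the assert in `PatchEmbed`):

import torch
net = Swintransformer(img_size=224)
feats = net(torch.randn(1, 3, 224, 224))
# with embed_dim=128 and depths [2, 2, 18, 2]: four maps of
# 128x56x56, 256x28x28, 512x14x14, 512x14x14 (the last Swin stage is omitted)
for f in feats:
    print(f.shape)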
sod/configs/prediction/default.yaml
ADDED
@@ -0,0 +1,14 @@
indir: no  # to be overridden in CLI
outdir: no  # to be overridden in CLI

model:
  path: no  # to be overridden in CLI
  checkpoint: best.ckpt

dataset:
  kind: default
  img_suffix: .png
  pad_out_to_modulo: 8

device: cuda
out_key: inpainted
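Bare `no` parses as boolean `False` under YAML 1.1, so `indir`, `outdir`, and `model.path` are placeholders meant to be overridden on the command line. A hedged sketch of reading the file (assuming OmegaConf, which LaMa-style prediction configs like this one are typically loaded with):

from omegaconf import OmegaConf

cfg = OmegaConf.load('sod/configs/prediction/default.yaml')
print(cfg.model.checkpoint)           # 'best.ckpt'
print(cfg.dataset.pad_out_to_modulo)  # 8
print(cfg.indir)                      # False until overridden, e.g. via OmegaConf.merge(cfg, OmegaConf.from_cli())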
sod/infer_model.py
ADDED
@@ -0,0 +1,89 @@
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F

import sys
sys.path.insert(0, '../')
sys.dont_write_bytecode = True
from .PGNet import PGNet

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = (image - self.mean) / self.std
        return image

class Config(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.mean = np.array([[[124.55, 118.90, 102.94]]])
        self.std = np.array([[[56.77, 55.97, 57.50]]])
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s' % (k, v))

    def __getattr__(self, name):
        if name in self.kwargs:
            return self.kwargs[name]
        else:
            return None

class IVModel():
    def __init__(self, device=torch.device('cuda:0')):
        super(IVModel, self).__init__()
        self.device = device
        checkpoint_path = 'sod/weights/PGNet_DUT+HR-model-31.pth'
        self.cfg = Config(snapshot=checkpoint_path, mode='test')
        if not os.path.exists(checkpoint_path):
            print('Model weights file not found!')
        self.net = PGNet(self.cfg)
        self.net.train(False)
        self.net.to(device)
        self.normalize = Normalize(mean=self.cfg.mean, std=self.cfg.std)

        self.__first_forward__()

    def __first_forward__(self, input_size=(2048, 2048, 3)):
        # Run forward() once on a maximum-size input to strictly bound peak GPU memory.
        print('initialize Sod Model...')
        _ = self.forward(np.random.rand(*input_size) * 255, None)
        print('initialize Complete!')

    def __resize_tensor__(self, image, max_size=1024):
        h, w = image.size()[2:]
        if max(h, w) > max_size:
            if h < w:
                h, w = int(max_size * h / w) // 8 * 8, max_size
            else:
                h, w = max_size, int(max_size * w / h) // 8 * 8
            image = F.interpolate(image, (h, w), mode='area')
        return image

    def input_preprocess_tensor(self, img):
        img = self.normalize(img)
        img_t = torch.from_numpy(img.astype(np.float32))  # .to(self.device)
        img_t = img_t.permute(2, 0, 1).unsqueeze(0)
        img_t = self.__resize_tensor__(img_t).to(self.device)  # cap the peak GPU memory footprint
        return img_t

    def forward(self, img, json_data):
        img_t = self.input_preprocess_tensor(img)
        shape = [torch.as_tensor([img_t.shape[2]]), torch.as_tensor([img_t.shape[3]])]
        h, w = img_t.shape[2], img_t.shape[3]
        img_t_temp = F.interpolate(img_t, (1024, 1024), mode='area')
        with torch.no_grad():
            res = self.net(img_t_temp, shape=shape)
            res = F.interpolate(res[0], size=shape, mode='bilinear')
            res = torch.sigmoid(res)
            # print(res.shape, img_t.shape, res.expand_as(img_t).shape)
            res = torch.cat([img_t, res.expand_as(img_t)], dim=3)
            res = (res[0].permute(1, 2, 0)).cpu().numpy()
        res[:, :w, :] = res[:, :w, :] * self.cfg.std + self.cfg.mean
        res[:, w:, :] = res[:, w:, :] * 255
        return res
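A minimal end-to-end sketch of `IVModel` (hypothetical output path; assumes the PGNet weights are present under sod/weights). `forward` takes an RGB array and returns the resized input and its saliency map concatenated along the width, both scaled back to 0-255:

import cv2
import numpy as np
import torch
from sod.infer_model import IVModel

model = IVModel(device=torch.device('cpu'))
img = cv2.imread('examples/SOD001.jpg')[:, :, ::-1]   # BGR -> RGB
res = model.forward(img, None)                        # H x 2W x 3, float
cv2.imwrite('saliency.png', np.ascontiguousarray(np.uint8(res)[:, :, ::-1]))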