wenpeng committed on
Commit c7813b3 · 1 Parent(s): 9f59d48

update .gitignore

.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ inpaint/weights
+ sod/weights
+ **/__pycache__
+ flagged
+ **/*.zip
app.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+ import inpaint.infer_model as inpaint
+ import sod.infer_model as sod
+ import numpy as np
+ import torch
+ import os
+ # cmd = 'sh download.sh'
+ # os.system(cmd)
+
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+ inpaint_model = inpaint.IVModel(device=device)
+ sod_model = sod.IVModel(device=torch.device("cpu"))
+
+ def sod_inpaint(img):
+     img = img[:, :, ::-1]                    # Gradio passes RGB; the models work in BGR
+     res = sod_model.forward(img, None)       # SOD output (image with mask appended) feeds the inpainting stage
+     res = np.uint8(res)
+     res = inpaint_model.forward(res, None)   # LaMa inpainting erases the masked (salient) region
+     res = np.uint8(res)
+     return res[:, :, ::-1]                   # back to RGB for Gradio
+
+ # UI text (Chinese): title = "Salient Object Removal"; description = "An image API that automatically removes the salient object from the picture."
+ iface = gr.Interface(fn=sod_inpaint, inputs="image", outputs="image", examples='examples', title='显著物体消除', description='这是一个图像API,功能是自动把画面中的显著物体消除', theme='huggingface')
+ iface.launch(server_name='0.0.0.0', share=False)
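For local debugging, the same pipeline can be exercised without the Gradio UI. A minimal sketch, assuming the weights have been fetched by download.sh and that sod.infer_model.IVModel (not part of this commit) follows the same interface as the inpaint wrapper and returns the image with its predicted mask appended on the right:

import cv2
import numpy as np
import torch
import inpaint.infer_model as inpaint
import sod.infer_model as sod

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
inpaint_model = inpaint.IVModel(device=device)
sod_model = sod.IVModel(device=torch.device("cpu"))

img = cv2.imread('examples/SOD001.jpg')                  # BGR, matching what sod_inpaint feeds the models
with_mask = np.uint8(sod_model.forward(img, None))       # image + saliency mask, side by side (assumed format)
result = np.uint8(inpaint_model.forward(with_mask, None))
cv2.imwrite('removed.jpg', result)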
download.sh ADDED
@@ -0,0 +1,9 @@
+ FILE_ID=1udSLeuWAZf2-uI7SI8dFEfBuyvLSpB9V
+ checkpoint_path='inpaint/weights.zip'
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILE_ID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILE_ID}" -O ${checkpoint_path} && rm -rf /tmp/cookies.txt
+ unzip $checkpoint_path
+
+ FILE_ID=1qI8-HBTz2nNSTyD9iB4YMn067P7XMKuD
+ checkpoint_path='sod/weights.zip'
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILE_ID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILE_ID}" -O ${checkpoint_path} && rm -rf /tmp/cookies.txt
+ unzip $checkpoint_path
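The script is meant to be run once from the repository root (the commented-out `sh download.sh` call in app.py would invoke it the same way). An equivalent sketch in Python, assuming the optional gdown package is installed, using the same Google Drive file IDs as above:

import zipfile
import gdown

for file_id, archive in [('1udSLeuWAZf2-uI7SI8dFEfBuyvLSpB9V', 'inpaint/weights.zip'),
                         ('1qI8-HBTz2nNSTyD9iB4YMn067P7XMKuD', 'sod/weights.zip')]:
    gdown.download(f'https://drive.google.com/uc?id={file_id}', archive, quiet=False)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall()  # like download.sh, unpack into the current directory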
examples/SOD001.jpg ADDED
examples/SOD003.jpeg ADDED
examples/SOD013.jpg ADDED
examples/SOD015.jpg ADDED
inpaint/configs/prediction/default.yaml ADDED
@@ -0,0 +1,14 @@
+ indir: no  # to be overridden in CLI
+ outdir: no  # to be overridden in CLI
+
+ model:
+   path: no  # to be overridden in CLI
+   checkpoint: best.ckpt
+
+ dataset:
+   kind: default
+   img_suffix: .png
+   pad_out_to_modulo: 8
+
+ device: cuda
+ out_key: inpainted
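The fields marked "to be overridden" are never read from this file as-is; infer_model.py and predict.py load the config with OmegaConf and assign them in code. A sketch of that pattern (the indir/outdir values below are placeholders):

from omegaconf import OmegaConf

predict_config = OmegaConf.load('inpaint/configs/prediction/default.yaml')
predict_config.model.path = 'inpaint/weights/big-lama'   # what infer_model.py sets
predict_config.indir = '/path/to/images_with_masks'      # only predict.py uses indir/outdir
predict_config.outdir = '/path/to/results'
print(OmegaConf.to_yaml(predict_config))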
inpaint/infer_model.py ADDED
@@ -0,0 +1,90 @@
+ import os
+ import sys
+
+ sys.path.append('inpaint')
+ os.environ['OMP_NUM_THREADS'] = '1'
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
+ os.environ['MKL_NUM_THREADS'] = '1'
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import yaml
+ from omegaconf import OmegaConf
+ from saicinpainting.training.trainers import load_checkpoint
+
+
+ class IVModel():
+     def __init__(self, device=torch.device('cuda:0')):
+         super(IVModel, self).__init__()
+         self.device = device
+         conf_path = 'inpaint/configs/prediction/default.yaml'
+         if not os.path.exists(conf_path):
+             print('Config file not found!')
+         predict_config = OmegaConf.load(conf_path)
+         predict_config.model.path = 'inpaint/weights/big-lama'
+         train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
+         with open(train_config_path, 'r') as f:
+             train_config = OmegaConf.create(yaml.safe_load(f))
+
+         train_config.training_model.predict_only = True
+         train_config.visualizer.kind = 'noop'
+
+         checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
+         self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
+         self.model.freeze()
+         self.model.to(device)
+
+         self.__first_forward__()
+
+     def __first_forward__(self, input_size=(2048, 4096, 3)):
+         # Run forward() once at the largest supported size to pin down peak GPU memory use.
+         print('initialize Inpaint Model...')
+         _ = self.forward(np.random.rand(*input_size) * 255, None)
+         print('initialize Complete!')
+
+     def __resize_tensor__(self, image, max_size=1024, scale_factor=8):
+         h, w = image.size()[2:]
+         if max(h, w) > max_size:
+             if h < w:
+                 h, w = int(max_size * h / w), max_size
+             else:
+                 h, w = max_size, int(max_size * w / h)
+         h = h // scale_factor * scale_factor
+         w = w // scale_factor * scale_factor
+         image = F.interpolate(image, (h, w), mode='bicubic')
+         return image
+
+     def input_preprocess_tensor(self, img):
+         img_t = torch.from_numpy(img.astype(np.float32))  # .to(self.device)
+         img_t = img_t.permute(2, 0, 1).unsqueeze(0)
+         img_t = img_t / 255.
+         img_t_for_net = self.__resize_tensor__(img_t).to(self.device)  # capped at 1024 px to bound GPU memory
+         img_t_for_out = self.__resize_tensor__(img_t, max_size=2048).to(self.device)  # capped at 2048 px for the output
+         return img_t_for_net, img_t_for_out
+
+     def forward(self, img, json_data):
+         # Input layout: [image | mask] concatenated along the width, image on the left.
+         _, w, _ = img.shape
+         mask = img[:, w // 2:, 0]
+         kernel = np.ones((4, 4), np.uint8)
+         mask = cv2.dilate(mask, kernel, iterations=5)[:, :, np.newaxis]
+         img = img[:, :w // 2, ::-1]
+         # print(img.shape, mask.shape)
+         img_t_for_net, img_t_for_out = self.input_preprocess_tensor(img)
+         mask_t_for_net, mask_t_for_out = self.input_preprocess_tensor(mask)
+         h, w = img_t_for_out.shape[2:]
+         # print(img_t.shape, mask_t.shape)
+         mask_t_for_out = (mask_t_for_out > 0).int()
+         mask_t_for_net = (mask_t_for_net > 0).int()
+         with torch.no_grad():
+             res_t = self.model(dict(image=img_t_for_net, mask=mask_t_for_net))['inpainted']
+             res_t = F.interpolate(res_t, (h, w), mode='bicubic')
+             res_t = img_t_for_out * (1 - mask_t_for_out) + res_t * mask_t_for_out
+             res_t = torch.clip(res_t * 255, min=0, max=255)
+         res = res_t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
+         res = res[:, :, (2, 1, 0)].astype(np.uint8)
+         return res
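The wrapper can also be driven directly. A minimal sketch, assuming the big-lama weights are unpacked under inpaint/weights/big-lama; forward() expects one array holding the BGR image on the left half and its mask on the right half, which is the format the SOD stage produces in app.py (the rectangle below is a hypothetical mask):

import cv2
import numpy as np
import torch
import inpaint.infer_model as inpaint

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = inpaint.IVModel(device=device)

img = cv2.imread('examples/SOD001.jpg')            # BGR, H x W x 3
mask = np.zeros_like(img)
mask[100:200, 150:300] = 255                       # hypothetical region to erase
side_by_side = np.concatenate([img, mask], axis=1) # [image | mask], width doubles
result = model.forward(side_by_side, None)         # uint8 BGR output, resized to at most 2048 px
cv2.imwrite('inpainted.jpg', result)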
inpaint/predict.py ADDED
@@ -0,0 +1,96 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Example command:
4
+ # ./bin/predict.py \
5
+ # model.path=<path to checkpoint, prepared by make_checkpoint.py> \
6
+ # indir=<path to input data> \
7
+ # outdir=<where to store predicts>
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ import traceback
13
+
14
+ from saicinpainting.evaluation.utils import move_to_device
15
+
16
+ os.environ['OMP_NUM_THREADS'] = '1'
17
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
18
+ os.environ['MKL_NUM_THREADS'] = '1'
19
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
20
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
21
+
22
+ import cv2
23
+ # import hydra
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import yaml
28
+ from omegaconf import OmegaConf
29
+ from torch.utils.data._utils.collate import default_collate
30
+
31
+ from saicinpainting.training.data.datasets import make_default_val_dataset
32
+ from saicinpainting.training.trainers import load_checkpoint
33
+ from saicinpainting.utils import register_debug_signal_handlers
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
+ # @hydra.main(config_path='../configs/prediction', config_name='default.yaml')
39
+ def main(predict_config: OmegaConf):
40
+ try:
41
+ register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
42
+
43
+ device = torch.device(predict_config.device)
44
+ print(predict_config)
45
+ train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
46
+ with open(train_config_path, 'r') as f:
47
+ train_config = OmegaConf.create(yaml.safe_load(f))
48
+
49
+ train_config.training_model.predict_only = True
50
+ train_config.visualizer.kind = 'noop'
51
+
52
+ out_ext = predict_config.get('out_ext', '.png')
53
+
54
+ checkpoint_path = os.path.join(predict_config.model.path,
55
+ 'models',
56
+ predict_config.model.checkpoint)
57
+ model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
58
+ model.freeze()
59
+ model.to(device)
60
+
61
+ if not predict_config.indir.endswith('/'):
62
+ predict_config.indir += '/'
63
+
64
+ dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
65
+ with torch.no_grad():
66
+ for img_i in tqdm.trange(len(dataset)):
67
+ mask_fname = dataset.mask_filenames[img_i]
68
+ cur_out_fname = os.path.join(
69
+ predict_config.outdir,
70
+ os.path.splitext(os.path.basename(mask_fname))[0] + out_ext
71
+ )
72
+ os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
73
+
74
+ batch = move_to_device(default_collate([dataset[img_i]]), device)
75
+ batch['mask'] = (batch['mask'] > 0) * 1
76
+ # print(torch.max(batch['mask']), torch.min(batch['mask']), torch.max(batch['image']), torch.min(batch['image']))
77
+ print(batch['mask'].dtype)
78
+ batch = model(batch)
79
+ cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
80
+
81
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
82
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
83
+ cv2.imwrite(cur_out_fname, cur_res)
84
+ except KeyboardInterrupt:
85
+ LOGGER.warning('Interrupted by user')
86
+ except Exception as ex:
87
+ LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
88
+ sys.exit(1)
89
+
90
+
91
+ if __name__ == '__main__':
92
+ base_conf = OmegaConf.load('configs/prediction/default.yaml')
93
+ base_conf.model.path='../../weights/big-lama'
94
+ base_conf.indir='/home/zwp/Temp2018/ZWP/data/example/照片补全/测试图片/带mask的图片2'
95
+ base_conf.outdir='zb_result/带mask的图片2'
96
+ main(predict_config=base_conf)
inpaint/saicinpainting/training/modules/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from saicinpainting.training.modules.ffc import FFCResNetGenerator
+
+ def make_generator(config, kind, **kwargs):
+     return FFCResNetGenerator(**kwargs)
+
+
+
inpaint/saicinpainting/training/modules/base.py ADDED
@@ -0,0 +1,80 @@
1
+ import abc
2
+ from typing import Tuple, List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
8
+ from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
9
+
10
+
11
+ class BaseDiscriminator(nn.Module):
12
+ @abc.abstractmethod
13
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
14
+ """
15
+ Predict scores and get intermediate activations. Useful for feature matching loss
16
+ :return tuple (scores, list of intermediate activations)
17
+ """
18
+ raise NotImplementedError()
19
+
20
+
21
+ def get_conv_block_ctor(kind='default'):
22
+ if not isinstance(kind, str):
23
+ return kind
24
+ if kind == 'default':
25
+ return nn.Conv2d
26
+ if kind == 'depthwise':
27
+ return DepthWiseSeperableConv
28
+ if kind == 'multidilated':
29
+ return MultidilatedConv
30
+ raise ValueError(f'Unknown convolutional block kind {kind}')
31
+
32
+
33
+ def get_norm_layer(kind='bn'):
34
+ if not isinstance(kind, str):
35
+ return kind
36
+ if kind == 'bn':
37
+ return nn.BatchNorm2d
38
+ if kind == 'in':
39
+ return nn.InstanceNorm2d
40
+ raise ValueError(f'Unknown norm block kind {kind}')
41
+
42
+
43
+ def get_activation(kind='tanh'):
44
+ if kind == 'tanh':
45
+ return nn.Tanh()
46
+ if kind == 'sigmoid':
47
+ return nn.Sigmoid()
48
+ if kind is False:
49
+ return nn.Identity()
50
+ raise ValueError(f'Unknown activation kind {kind}')
51
+
52
+
53
+ class SimpleMultiStepGenerator(nn.Module):
54
+ def __init__(self, steps: List[nn.Module]):
55
+ super().__init__()
56
+ self.steps = nn.ModuleList(steps)
57
+
58
+ def forward(self, x):
59
+ cur_in = x
60
+ outs = []
61
+ for step in self.steps:
62
+ cur_out = step(cur_in)
63
+ outs.append(cur_out)
64
+ cur_in = torch.cat((cur_in, cur_out), dim=1)
65
+ return torch.cat(outs[::-1], dim=1)
66
+
67
+ def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
68
+ if kind == 'convtranspose':
69
+ return [nn.ConvTranspose2d(min(max_features, ngf * mult),
70
+ min(max_features, int(ngf * mult / 2)),
71
+ kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ norm_layer(min(max_features, int(ngf * mult / 2))), activation]
73
+ elif kind == 'bilinear':
74
+ return [nn.Upsample(scale_factor=2, mode='bilinear'),
75
+ DepthWiseSeperableConv(min(max_features, ngf * mult),
76
+ min(max_features, int(ngf * mult / 2)),
77
+ kernel_size=3, stride=1, padding=1),
78
+ norm_layer(min(max_features, int(ngf * mult / 2))), activation]
79
+ else:
80
+ raise Exception(f"Invalid deconv kind: {kind}")
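For orientation, the three factories above map string "kinds" from the training configs to layer constructors. A small sketch using kinds defined in this file:

import torch
import torch.nn as nn
from saicinpainting.training.modules.base import get_conv_block_ctor, get_norm_layer, get_activation

block = nn.Sequential(
    get_conv_block_ctor('depthwise')(64, 64, kernel_size=3, padding=1),   # -> DepthWiseSeperableConv
    get_norm_layer('bn')(64),                                             # -> nn.BatchNorm2d
    get_activation('tanh'),                                               # -> nn.Tanh()
)
print(block(torch.randn(1, 64, 16, 16)).shape)   # torch.Size([1, 64, 16, 16])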
inpaint/saicinpainting/training/modules/depthwise_sep_conv.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import torch.nn as nn
+
+ class DepthWiseSeperableConv(nn.Module):
+     def __init__(self, in_dim, out_dim, *args, **kwargs):
+         super().__init__()
+         if 'groups' in kwargs:
+             # ignoring groups for Depthwise Sep Conv
+             del kwargs['groups']
+
+         self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
+         self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+     def forward(self, x):
+         out = self.depthwise(x)
+         out = self.pointwise(out)
+         return out
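A quick sketch of why this layer exists: it is used as a drop-in replacement for nn.Conv2d (e.g. via get_conv_block_ctor('depthwise') above), factoring a dense convolution into a per-channel step plus a 1x1 pointwise step, which cuts the parameter count sharply:

import torch
import torch.nn as nn
from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv

dense = nn.Conv2d(64, 128, kernel_size=3, padding=1)
separable = DepthWiseSeperableConv(64, 128, kernel_size=3, padding=1)
x = torch.randn(1, 64, 32, 32)
assert dense(x).shape == separable(x).shape              # same output shape
print(sum(p.numel() for p in dense.parameters()),        # 73,856
      sum(p.numel() for p in separable.parameters()))    # 8,960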
inpaint/saicinpainting/training/modules/fake_fakes.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from kornia import SamplePadding
3
+ from kornia.augmentation import RandomAffine, CenterCrop
4
+
5
+
6
+ class FakeFakesGenerator:
7
+ def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8
+ self.grad_aug = RandomAffine(degrees=360,
9
+ translate=0.2,
10
+ padding_mode=SamplePadding.REFLECTION,
11
+ keepdim=False,
12
+ p=1)
13
+ self.img_aug = RandomAffine(degrees=img_aug_degree,
14
+ translate=img_aug_translate,
15
+ padding_mode=SamplePadding.REFLECTION,
16
+ keepdim=True,
17
+ p=1)
18
+ self.aug_proba = aug_proba
19
+
20
+ def __call__(self, input_images, masks):
21
+ blend_masks = self._fill_masks_with_gradient(masks)
22
+ blend_target = self._make_blend_target(input_images)
23
+ result = input_images * (1 - blend_masks) + blend_target * blend_masks
24
+ return result, blend_masks
25
+
26
+ def _make_blend_target(self, input_images):
27
+ batch_size = input_images.shape[0]
28
+ permuted = input_images[torch.randperm(batch_size)]
29
+ augmented = self.img_aug(input_images)
30
+ is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31
+ result = augmented * is_aug + permuted * (1 - is_aug)
32
+ return result
33
+
34
+ def _fill_masks_with_gradient(self, masks):
35
+ batch_size, _, height, width = masks.shape
36
+ grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37
+ .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38
+ grad = self.grad_aug(grad)
39
+ grad = CenterCrop((height, width))(grad)
40
+ grad *= masks
41
+
42
+ grad_for_min = grad + (1 - masks) * 10
43
+ grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44
+ grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45
+ grad.clamp_(min=0, max=1)
46
+
47
+ return grad
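A usage sketch (shapes are illustrative, and it requires a kornia version where `from kornia import SamplePadding` still resolves): the generator blends each image with an augmented or shuffled neighbour inside the masked region, producing "fake fakes" for discriminator training.

import torch
from saicinpainting.training.modules.fake_fakes import FakeFakesGenerator

gen = FakeFakesGenerator(aug_proba=0.5)
images = torch.rand(4, 3, 128, 128)
masks = (torch.rand(4, 1, 128, 128) > 0.7).float()
fake_fakes, blend_masks = gen(images, masks)   # blended images and the gradient blend masks
print(fake_fakes.shape, blend_masks.shape)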
inpaint/saicinpainting/training/modules/ffc.py ADDED
@@ -0,0 +1,367 @@
1
+ # Fast Fourier Convolution NeurIPS 2020
2
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
3
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from saicinpainting.training.modules.base import get_activation
10
+ from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
11
+ from saicinpainting.training.modules.squeeze_excitation import SELayer
12
+
13
+
14
+ class FFCSE_block(nn.Module):
15
+
16
+ def __init__(self, channels, ratio_g):
17
+ super(FFCSE_block, self).__init__()
18
+ in_cg = int(channels * ratio_g)
19
+ in_cl = channels - in_cg
20
+ r = 16
21
+
22
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
23
+ self.conv1 = nn.Conv2d(channels, channels // r,
24
+ kernel_size=1, bias=True)
25
+ self.relu1 = nn.ReLU(inplace=True)
26
+ self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
27
+ channels // r, in_cl, kernel_size=1, bias=True)
28
+ self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
29
+ channels // r, in_cg, kernel_size=1, bias=True)
30
+ self.sigmoid = nn.Sigmoid()
31
+
32
+ def forward(self, x):
33
+ x = x if type(x) is tuple else (x, 0)
34
+ id_l, id_g = x
35
+
36
+ x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
37
+ x = self.avgpool(x)
38
+ x = self.relu1(self.conv1(x))
39
+
40
+ x_l = 0 if self.conv_a2l is None else id_l * \
41
+ self.sigmoid(self.conv_a2l(x))
42
+ x_g = 0 if self.conv_a2g is None else id_g * \
43
+ self.sigmoid(self.conv_a2g(x))
44
+ return x_l, x_g
45
+
46
+
47
+ class FourierUnit(nn.Module):
48
+
49
+ def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
50
+ spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
51
+ # bn_layer not used
52
+ super(FourierUnit, self).__init__()
53
+ self.groups = groups
54
+
55
+ self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
56
+ out_channels=out_channels * 2,
57
+ kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
58
+ self.bn = torch.nn.BatchNorm2d(out_channels * 2)
59
+ self.relu = torch.nn.ReLU(inplace=True)
60
+
61
+ # squeeze and excitation block
62
+ self.use_se = use_se
63
+ if use_se:
64
+ if se_kwargs is None:
65
+ se_kwargs = {}
66
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
67
+
68
+ self.spatial_scale_factor = spatial_scale_factor
69
+ self.spatial_scale_mode = spatial_scale_mode
70
+ self.spectral_pos_encoding = spectral_pos_encoding
71
+ self.ffc3d = ffc3d
72
+ self.fft_norm = fft_norm
73
+
74
+ def forward(self, x):
75
+ batch = x.shape[0]
76
+
77
+ if self.spatial_scale_factor is not None:
78
+ orig_size = x.shape[-2:]
79
+ x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
80
+
81
+ r_size = x.size()
82
+ # (batch, c, h, w/2+1, 2)
83
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
84
+ ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
85
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
86
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
87
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
88
+
89
+ if self.spectral_pos_encoding:
90
+ height, width = ffted.shape[-2:]
91
+ coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
92
+ coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
93
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
94
+
95
+ if self.use_se:
96
+ ffted = self.se(ffted)
97
+
98
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
99
+ ffted = self.relu(self.bn(ffted))
100
+
101
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
102
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
103
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
104
+
105
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
106
+ output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
107
+
108
+ if self.spatial_scale_factor is not None:
109
+ output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
110
+
111
+ return output
112
+
113
+
114
+ class SpectralTransform(nn.Module):
115
+
116
+ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
117
+ # bn_layer not used
118
+ super(SpectralTransform, self).__init__()
119
+ self.enable_lfu = enable_lfu
120
+ if stride == 2:
121
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
122
+ else:
123
+ self.downsample = nn.Identity()
124
+
125
+ self.stride = stride
126
+ self.conv1 = nn.Sequential(
127
+ nn.Conv2d(in_channels, out_channels //
128
+ 2, kernel_size=1, groups=groups, bias=False),
129
+ nn.BatchNorm2d(out_channels // 2),
130
+ nn.ReLU(inplace=True)
131
+ )
132
+ self.fu = FourierUnit(
133
+ out_channels // 2, out_channels // 2, groups, **fu_kwargs)
134
+ if self.enable_lfu:
135
+ self.lfu = FourierUnit(
136
+ out_channels // 2, out_channels // 2, groups)
137
+ self.conv2 = torch.nn.Conv2d(
138
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
139
+
140
+ def forward(self, x):
141
+
142
+ x = self.downsample(x)
143
+ x = self.conv1(x)
144
+ output = self.fu(x)
145
+
146
+ if self.enable_lfu:
147
+ n, c, h, w = x.shape
148
+ split_no = 2
149
+ split_s = h // split_no
150
+ xs = torch.cat(torch.split(
151
+ x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
152
+ xs = torch.cat(torch.split(xs, split_s, dim=-1),
153
+ dim=1).contiguous()
154
+ xs = self.lfu(xs)
155
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
156
+ else:
157
+ xs = 0
158
+
159
+ output = self.conv2(x + output + xs)
160
+
161
+ return output
162
+
163
+
164
+ class FFC(nn.Module):
165
+
166
+ def __init__(self, in_channels, out_channels, kernel_size,
167
+ ratio_gin, ratio_gout, stride=1, padding=0,
168
+ dilation=1, groups=1, bias=False, enable_lfu=True,
169
+ padding_type='reflect', gated=False, **spectral_kwargs):
170
+ super(FFC, self).__init__()
171
+
172
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
173
+ self.stride = stride
174
+
175
+ in_cg = int(in_channels * ratio_gin)
176
+ in_cl = in_channels - in_cg
177
+ out_cg = int(out_channels * ratio_gout)
178
+ out_cl = out_channels - out_cg
179
+ #groups_g = 1 if groups == 1 else int(groups * ratio_gout)
180
+ #groups_l = 1 if groups == 1 else groups - groups_g
181
+
182
+ self.ratio_gin = ratio_gin
183
+ self.ratio_gout = ratio_gout
184
+ self.global_in_num = in_cg
185
+
186
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
187
+ self.convl2l = module(in_cl, out_cl, kernel_size,
188
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
189
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
190
+ self.convl2g = module(in_cl, out_cg, kernel_size,
191
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
192
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
193
+ self.convg2l = module(in_cg, out_cl, kernel_size,
194
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
195
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
196
+ self.convg2g = module(
197
+ in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
198
+
199
+ self.gated = gated
200
+ module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
201
+ self.gate = module(in_channels, 2, 1)
202
+
203
+ def forward(self, x):
204
+ x_l, x_g = x if type(x) is tuple else (x, 0)
205
+ out_xl, out_xg = 0, 0
206
+
207
+ if self.gated:
208
+ total_input_parts = [x_l]
209
+ if torch.is_tensor(x_g):
210
+ total_input_parts.append(x_g)
211
+ total_input = torch.cat(total_input_parts, dim=1)
212
+
213
+ gates = torch.sigmoid(self.gate(total_input))
214
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
215
+ else:
216
+ g2l_gate, l2g_gate = 1, 1
217
+
218
+ if self.ratio_gout != 1:
219
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
220
+ if self.ratio_gout != 0:
221
+ out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
222
+
223
+ return out_xl, out_xg
224
+
225
+
226
+ class FFC_BN_ACT(nn.Module):
227
+
228
+ def __init__(self, in_channels, out_channels,
229
+ kernel_size, ratio_gin, ratio_gout,
230
+ stride=1, padding=0, dilation=1, groups=1, bias=False,
231
+ norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
232
+ padding_type='reflect',
233
+ enable_lfu=True, **kwargs):
234
+ super(FFC_BN_ACT, self).__init__()
235
+ self.ffc = FFC(in_channels, out_channels, kernel_size,
236
+ ratio_gin, ratio_gout, stride, padding, dilation,
237
+ groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
238
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
239
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
240
+ global_channels = int(out_channels * ratio_gout)
241
+ self.bn_l = lnorm(out_channels - global_channels)
242
+ self.bn_g = gnorm(global_channels)
243
+
244
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
245
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
246
+ self.act_l = lact(inplace=True)
247
+ self.act_g = gact(inplace=True)
248
+
249
+ def forward(self, x):
250
+ x_l, x_g = self.ffc(x)
251
+ x_l = self.act_l(self.bn_l(x_l))
252
+ x_g = self.act_g(self.bn_g(x_g))
253
+ return x_l, x_g
254
+
255
+
256
+ class FFCResnetBlock(nn.Module):
257
+ def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
258
+ spatial_transform_kwargs=None, inline=False, **conv_kwargs):
259
+ super().__init__()
260
+ self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
261
+ norm_layer=norm_layer,
262
+ activation_layer=activation_layer,
263
+ padding_type=padding_type,
264
+ **conv_kwargs)
265
+ self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
266
+ norm_layer=norm_layer,
267
+ activation_layer=activation_layer,
268
+ padding_type=padding_type,
269
+ **conv_kwargs)
270
+ if spatial_transform_kwargs is not None:
271
+ self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
272
+ self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
273
+ self.inline = inline
274
+
275
+ def forward(self, x):
276
+ if self.inline:
277
+ x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
278
+ else:
279
+ x_l, x_g = x if type(x) is tuple else (x, 0)
280
+
281
+ id_l, id_g = x_l, x_g
282
+
283
+ x_l, x_g = self.conv1((x_l, x_g))
284
+ x_l, x_g = self.conv2((x_l, x_g))
285
+
286
+ x_l, x_g = id_l + x_l, id_g + x_g
287
+ out = x_l, x_g
288
+ if self.inline:
289
+ out = torch.cat(out, dim=1)
290
+ return out
291
+
292
+
293
+ class ConcatTupleLayer(nn.Module):
294
+ def forward(self, x):
295
+ assert isinstance(x, tuple)
296
+ x_l, x_g = x
297
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
298
+ if not torch.is_tensor(x_g):
299
+ return x_l
300
+ return torch.cat(x, dim=1)
301
+
302
+
303
+ class FFCResNetGenerator(nn.Module):
304
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
305
+ padding_type='reflect', activation_layer=nn.ReLU,
306
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
307
+ init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
308
+ spatial_transform_layers=None, spatial_transform_kwargs={},
309
+ add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
310
+ assert (n_blocks >= 0)
311
+ super().__init__()
312
+
313
+ model = [nn.ReflectionPad2d(3),
314
+ FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
315
+ activation_layer=activation_layer, **init_conv_kwargs)]
316
+
317
+ ### downsample
318
+ for i in range(n_downsampling):
319
+ mult = 2 ** i
320
+ if i == n_downsampling - 1:
321
+ cur_conv_kwargs = dict(downsample_conv_kwargs)
322
+ cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
323
+ else:
324
+ cur_conv_kwargs = downsample_conv_kwargs
325
+ model += [FFC_BN_ACT(min(max_features, ngf * mult),
326
+ min(max_features, ngf * mult * 2),
327
+ kernel_size=3, stride=2, padding=1,
328
+ norm_layer=norm_layer,
329
+ activation_layer=activation_layer,
330
+ **cur_conv_kwargs)]
331
+
332
+ mult = 2 ** n_downsampling
333
+ feats_num_bottleneck = min(max_features, ngf * mult)
334
+
335
+ ### resnet blocks
336
+ for i in range(n_blocks):
337
+ cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
338
+ norm_layer=norm_layer, **resnet_conv_kwargs)
339
+ if spatial_transform_layers is not None and i in spatial_transform_layers:
340
+ cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
341
+ model += [cur_resblock]
342
+
343
+ model += [ConcatTupleLayer()]
344
+
345
+ ### upsample
346
+ for i in range(n_downsampling):
347
+ mult = 2 ** (n_downsampling - i)
348
+ model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
349
+ min(max_features, int(ngf * mult / 2)),
350
+ kernel_size=3, stride=2, padding=1, output_padding=1),
351
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
352
+ up_activation]
353
+
354
+ if out_ffc:
355
+ model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
356
+ norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
357
+
358
+ model += [nn.ReflectionPad2d(3),
359
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
360
+ if add_out_act:
361
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
362
+ self.model = nn.Sequential(*model)
363
+
364
+ def forward(self, input):
365
+ return self.model(input)
366
+
367
+
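To see how the pieces fit together, here is a minimal sketch that instantiates FFCResNetGenerator the way make_generator() does. The ratio_gin/ratio_gout values and the 4-channel input (masked image + mask) are illustrative and are not taken from the big-lama training config, which ships inside the downloaded weights folder:

import torch
from saicinpainting.training.modules.ffc import FFCResNetGenerator

gen = FFCResNetGenerator(
    input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=2,
    init_conv_kwargs=dict(ratio_gin=0, ratio_gout=0),          # purely local branch at the stem
    downsample_conv_kwargs=dict(ratio_gin=0, ratio_gout=0),
    resnet_conv_kwargs=dict(ratio_gin=0.75, ratio_gout=0.75),  # mostly global (spectral) branch in the bottleneck
)
x = torch.randn(1, 4, 256, 256)   # 3 image channels + 1 mask channel
print(gen(x).shape)               # torch.Size([1, 3, 256, 256])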
inpaint/saicinpainting/training/modules/multidilated_conv.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+ from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5
+
6
+ class MultidilatedConv(nn.Module):
7
+ def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
8
+ shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
9
+ super().__init__()
10
+ convs = []
11
+ self.equal_dim = equal_dim
12
+ assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
13
+ if comb_mode in ('cat_out', 'cat_both'):
14
+ self.cat_out = True
15
+ if equal_dim:
16
+ assert out_dim % dilation_num == 0
17
+ out_dims = [out_dim // dilation_num] * dilation_num
18
+ self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
19
+ else:
20
+ out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
21
+ out_dims.append(out_dim - sum(out_dims))
22
+ index = []
23
+ starts = [0] + out_dims[:-1]
24
+ lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
25
+ for i in range(out_dims[-1]):
26
+ for j in range(dilation_num):
27
+ index += list(range(starts[j], starts[j] + lengths[j]))
28
+ starts[j] += lengths[j]
29
+ self.index = index
30
+ assert(len(index) == out_dim)
31
+ self.out_dims = out_dims
32
+ else:
33
+ self.cat_out = False
34
+ self.out_dims = [out_dim] * dilation_num
35
+
36
+ if comb_mode in ('cat_in', 'cat_both'):
37
+ if equal_dim:
38
+ assert in_dim % dilation_num == 0
39
+ in_dims = [in_dim // dilation_num] * dilation_num
40
+ else:
41
+ in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
42
+ in_dims.append(in_dim - sum(in_dims))
43
+ self.in_dims = in_dims
44
+ self.cat_in = True
45
+ else:
46
+ self.cat_in = False
47
+ self.in_dims = [in_dim] * dilation_num
48
+
49
+ conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
50
+ dilation = min_dilation
51
+ for i in range(dilation_num):
52
+ if isinstance(padding, int):
53
+ cur_padding = padding * dilation
54
+ else:
55
+ cur_padding = padding[i]
56
+ convs.append(conv_type(
57
+ self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
58
+ ))
59
+ if i > 0 and shared_weights:
60
+ convs[-1].weight = convs[0].weight
61
+ convs[-1].bias = convs[0].bias
62
+ dilation *= 2
63
+ self.convs = nn.ModuleList(convs)
64
+
65
+ self.shuffle_in_channels = shuffle_in_channels
66
+ if self.shuffle_in_channels:
67
+ # shuffle list as shuffling of tensors is nondeterministic
68
+ in_channels_permute = list(range(in_dim))
69
+ random.shuffle(in_channels_permute)
70
+ # save as buffer so it is saved and loaded with checkpoint
71
+ self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
72
+
73
+ def forward(self, x):
74
+ if self.shuffle_in_channels:
75
+ x = x[:, self.in_channels_permute]
76
+
77
+ outs = []
78
+ if self.cat_in:
79
+ if self.equal_dim:
80
+ x = x.chunk(len(self.convs), dim=1)
81
+ else:
82
+ new_x = []
83
+ start = 0
84
+ for dim in self.in_dims:
85
+ new_x.append(x[:, start:start+dim])
86
+ start += dim
87
+ x = new_x
88
+ for i, conv in enumerate(self.convs):
89
+ if self.cat_in:
90
+ input = x[i]
91
+ else:
92
+ input = x
93
+ outs.append(conv(input))
94
+ if self.cat_out:
95
+ out = torch.cat(outs, dim=1)[:, self.index]
96
+ else:
97
+ out = sum(outs)
98
+ return out
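A short sketch of the default 'sum' mode: three parallel convolutions with dilations 1, 2 and 4 (doubling from min_dilation) are applied and summed, and the padding scales with each dilation so the spatial size is preserved:

import torch
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv

conv = MultidilatedConv(in_dim=64, out_dim=96, kernel_size=3, dilation_num=3, comb_mode='sum')
y = conv(torch.randn(1, 64, 32, 32))
print(y.shape)   # torch.Size([1, 96, 32, 32])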
inpaint/saicinpainting/training/modules/multiscale.py ADDED
@@ -0,0 +1,244 @@
1
+ from typing import List, Tuple, Union, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
8
+ from saicinpainting.training.modules.pix2pixhd import ResnetBlock
9
+
10
+
11
+ class ResNetHead(nn.Module):
12
+ def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
13
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
14
+ assert (n_blocks >= 0)
15
+ super(ResNetHead, self).__init__()
16
+
17
+ conv_layer = get_conv_block_ctor(conv_kind)
18
+
19
+ model = [nn.ReflectionPad2d(3),
20
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
21
+ norm_layer(ngf),
22
+ activation]
23
+
24
+ ### downsample
25
+ for i in range(n_downsampling):
26
+ mult = 2 ** i
27
+ model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
28
+ norm_layer(ngf * mult * 2),
29
+ activation]
30
+
31
+ mult = 2 ** n_downsampling
32
+
33
+ ### resnet blocks
34
+ for i in range(n_blocks):
35
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
36
+ conv_kind=conv_kind)]
37
+
38
+ self.model = nn.Sequential(*model)
39
+
40
+ def forward(self, input):
41
+ return self.model(input)
42
+
43
+
44
+ class ResNetTail(nn.Module):
45
+ def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
46
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
47
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
48
+ add_in_proj=None):
49
+ assert (n_blocks >= 0)
50
+ super(ResNetTail, self).__init__()
51
+
52
+ mult = 2 ** n_downsampling
53
+
54
+ model = []
55
+
56
+ if add_in_proj is not None:
57
+ model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
58
+
59
+ ### resnet blocks
60
+ for i in range(n_blocks):
61
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
62
+ conv_kind=conv_kind)]
63
+
64
+ ### upsample
65
+ for i in range(n_downsampling):
66
+ mult = 2 ** (n_downsampling - i)
67
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
68
+ output_padding=1),
69
+ up_norm_layer(int(ngf * mult / 2)),
70
+ up_activation]
71
+ self.model = nn.Sequential(*model)
72
+
73
+ out_layers = []
74
+ for _ in range(out_extra_layers_n):
75
+ out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
76
+ up_norm_layer(ngf),
77
+ up_activation]
78
+ out_layers += [nn.ReflectionPad2d(3),
79
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
80
+
81
+ if add_out_act:
82
+ out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
83
+
84
+ self.out_proj = nn.Sequential(*out_layers)
85
+
86
+ def forward(self, input, return_last_act=False):
87
+ features = self.model(input)
88
+ out = self.out_proj(features)
89
+ if return_last_act:
90
+ return out, features
91
+ else:
92
+ return out
93
+
94
+
95
+ class MultiscaleResNet(nn.Module):
96
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
97
+ norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
98
+ up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
99
+ out_cumulative=False, return_only_hr=False):
100
+ super().__init__()
101
+
102
+ self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
103
+ n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
104
+ conv_kind=conv_kind, activation=activation)
105
+ for i in range(n_scales)])
106
+ tail_in_feats = ngf * (2 ** n_downsampling) + ngf
107
+ self.tails = nn.ModuleList([ResNetTail(output_nc,
108
+ ngf=ngf, n_downsampling=n_downsampling,
109
+ n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
110
+ conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
111
+ up_activation=up_activation, add_out_act=add_out_act,
112
+ out_extra_layers_n=out_extra_layers_n,
113
+ add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
114
+ for i in range(n_scales)])
115
+
116
+ self.out_cumulative = out_cumulative
117
+ self.return_only_hr = return_only_hr
118
+
119
+ @property
120
+ def num_scales(self):
121
+ return len(self.heads)
122
+
123
+ def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
124
+ -> Union[torch.Tensor, List[torch.Tensor]]:
125
+ """
126
+ :param ms_inputs: List of inputs of different resolutions from HR to LR
127
+ :param smallest_scales_num: int or None, number of smallest scales to take at input
128
+ :return: Depending on return_only_hr:
129
+ True: Only the most HR output
130
+ False: List of outputs of different resolutions from HR to LR
131
+ """
132
+ if smallest_scales_num is None:
133
+ assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
134
+ smallest_scales_num = len(self.heads)
135
+ else:
136
+ assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
137
+
138
+ cur_heads = self.heads[-smallest_scales_num:]
139
+ ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
140
+
141
+ all_outputs = []
142
+ prev_tail_features = None
143
+ for i in range(len(ms_features)):
144
+ scale_i = -i - 1
145
+
146
+ cur_tail_input = ms_features[-i - 1]
147
+ if prev_tail_features is not None:
148
+ if prev_tail_features.shape != cur_tail_input.shape:
149
+ prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
150
+ mode='bilinear', align_corners=False)
151
+ cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
152
+
153
+ cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
154
+
155
+ prev_tail_features = cur_tail_feats
156
+ all_outputs.append(cur_out)
157
+
158
+ if self.out_cumulative:
159
+ all_outputs_cum = [all_outputs[0]]
160
+ for i in range(1, len(ms_features)):
161
+ cur_out = all_outputs[i]
162
+ cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
163
+ mode='bilinear', align_corners=False)
164
+ all_outputs_cum.append(cur_out_cum)
165
+ all_outputs = all_outputs_cum
166
+
167
+ if self.return_only_hr:
168
+ return all_outputs[-1]
169
+ else:
170
+ return all_outputs[::-1]
171
+
172
+
173
+ class MultiscaleDiscriminatorSimple(nn.Module):
174
+ def __init__(self, ms_impl):
175
+ super().__init__()
176
+ self.ms_impl = nn.ModuleList(ms_impl)
177
+
178
+ @property
179
+ def num_scales(self):
180
+ return len(self.ms_impl)
181
+
182
+ def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
183
+ -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
184
+ """
185
+ :param ms_inputs: List of inputs of different resolutions from HR to LR
186
+ :param smallest_scales_num: int or None, number of smallest scales to take at input
187
+ :return: List of pairs (prediction, features) for different resolutions from HR to LR
188
+ """
189
+ if smallest_scales_num is None:
190
+ assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
191
+ smallest_scales_num = len(self.ms_impl)
192
+ else:
193
+ assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
194
+ (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
195
+
196
+ return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
197
+
198
+
199
+ class SingleToMultiScaleInputMixin:
200
+ def forward(self, x: torch.Tensor) -> List:
201
+ orig_height, orig_width = x.shape[2:]
202
+ factors = [2 ** i for i in range(self.num_scales)]
203
+ ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
204
+ for f in factors]
205
+ return super().forward(ms_inputs)
206
+
207
+
208
+ class GeneratorMultiToSingleOutputMixin:
209
+ def forward(self, x):
210
+ return super().forward(x)[0]
211
+
212
+
213
+ class DiscriminatorMultiToSingleOutputMixin:
214
+ def forward(self, x):
215
+ out_feat_tuples = super().forward(x)
216
+ return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
217
+
218
+
219
+ class DiscriminatorMultiToSingleOutputStackedMixin:
220
+ def __init__(self, *args, return_feats_only_levels=None, **kwargs):
221
+ super().__init__(*args, **kwargs)
222
+ self.return_feats_only_levels = return_feats_only_levels
223
+
224
+ def forward(self, x):
225
+ out_feat_tuples = super().forward(x)
226
+ outs = [out for out, _ in out_feat_tuples]
227
+ scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
228
+ mode='bilinear', align_corners=False)
229
+ for cur_out in outs[1:]]
230
+ out = torch.cat(scaled_outs, dim=1)
231
+ if self.return_feats_only_levels is not None:
232
+ feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
233
+ else:
234
+ feat_lists = [flist for _, flist in out_feat_tuples]
235
+ feats = [f for flist in feat_lists for f in flist]
236
+ return out, feats
237
+
238
+
239
+ class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
240
+ pass
241
+
242
+
243
+ class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
244
+ pass
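A minimal sketch of the multiscale generator on a three-level pyramid (sizes and block counts are illustrative). Inputs go from high to low resolution, as the docstring requires, and the tails run LR-first so each finer scale is conditioned on the coarser features:

import torch
import torch.nn.functional as F
from saicinpainting.training.modules.multiscale import MultiscaleResNet

net = MultiscaleResNet(input_nc=3, output_nc=3, ngf=64, n_downsampling=2,
                       n_blocks_head=1, n_blocks_tail=2, n_scales=3)
x = torch.randn(1, 3, 256, 256)
pyramid = [x,
           F.interpolate(x, scale_factor=0.5),
           F.interpolate(x, scale_factor=0.25)]   # HR -> LR
outs = net(pyramid)                               # list of outputs, HR first
print([tuple(o.shape) for o in outs])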
inpaint/saicinpainting/training/modules/pix2pixhd.py ADDED
@@ -0,0 +1,669 @@
1
+ # original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
2
+ import collections
3
+ from functools import partial
4
+ import functools
5
+ import logging
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ import torch.nn as nn
10
+
11
+ from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
12
+ from saicinpainting.training.modules.ffc import FFCResnetBlock
13
+ from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
14
+
15
+ class DotDict(defaultdict):
16
+ # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
17
+ """dot.notation access to dictionary attributes"""
18
+ __getattr__ = defaultdict.get
19
+ __setattr__ = defaultdict.__setitem__
20
+ __delattr__ = defaultdict.__delitem__
21
+
22
+ class Identity(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def forward(self, x):
27
+ return x
28
+
29
+
30
+ class ResnetBlock(nn.Module):
31
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
32
+ dilation=1, in_dim=None, groups=1, second_dilation=None):
33
+ super(ResnetBlock, self).__init__()
34
+ self.in_dim = in_dim
35
+ self.dim = dim
36
+ if second_dilation is None:
37
+ second_dilation = dilation
38
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
39
+ conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
40
+ second_dilation=second_dilation)
41
+
42
+ if self.in_dim is not None:
43
+ self.input_conv = nn.Conv2d(in_dim, dim, 1)
44
+
45
+ self.out_channnels = dim
46
+
47
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
48
+ dilation=1, in_dim=None, groups=1, second_dilation=1):
49
+ conv_layer = get_conv_block_ctor(conv_kind)
50
+
51
+ conv_block = []
52
+ p = 0
53
+ if padding_type == 'reflect':
54
+ conv_block += [nn.ReflectionPad2d(dilation)]
55
+ elif padding_type == 'replicate':
56
+ conv_block += [nn.ReplicationPad2d(dilation)]
57
+ elif padding_type == 'zero':
58
+ p = dilation
59
+ else:
60
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
61
+
62
+ if in_dim is None:
63
+ in_dim = dim
64
+
65
+ conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
66
+ norm_layer(dim),
67
+ activation]
68
+ if use_dropout:
69
+ conv_block += [nn.Dropout(0.5)]
70
+
71
+ p = 0
72
+ if padding_type == 'reflect':
73
+ conv_block += [nn.ReflectionPad2d(second_dilation)]
74
+ elif padding_type == 'replicate':
75
+ conv_block += [nn.ReplicationPad2d(second_dilation)]
76
+ elif padding_type == 'zero':
77
+ p = second_dilation
78
+ else:
79
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
80
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
81
+ norm_layer(dim)]
82
+
83
+ return nn.Sequential(*conv_block)
84
+
85
+ def forward(self, x):
86
+ x_before = x
87
+ if self.in_dim is not None:
88
+ x = self.input_conv(x)
89
+ out = x + self.conv_block(x_before)
90
+ return out
91
+
92
+ class ResnetBlock5x5(nn.Module):
93
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
94
+ dilation=1, in_dim=None, groups=1, second_dilation=None):
95
+ super(ResnetBlock5x5, self).__init__()
96
+ self.in_dim = in_dim
97
+ self.dim = dim
98
+ if second_dilation is None:
99
+ second_dilation = dilation
100
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
101
+ conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
102
+ second_dilation=second_dilation)
103
+
104
+ if self.in_dim is not None:
105
+ self.input_conv = nn.Conv2d(in_dim, dim, 1)
106
+
107
+ self.out_channnels = dim
108
+
109
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
110
+ dilation=1, in_dim=None, groups=1, second_dilation=1):
111
+ conv_layer = get_conv_block_ctor(conv_kind)
112
+
113
+ conv_block = []
114
+ p = 0
115
+ if padding_type == 'reflect':
116
+ conv_block += [nn.ReflectionPad2d(dilation * 2)]
117
+ elif padding_type == 'replicate':
118
+ conv_block += [nn.ReplicationPad2d(dilation * 2)]
119
+ elif padding_type == 'zero':
120
+ p = dilation * 2
121
+ else:
122
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
123
+
124
+ if in_dim is None:
125
+ in_dim = dim
126
+
127
+ conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
128
+ norm_layer(dim),
129
+ activation]
130
+ if use_dropout:
131
+ conv_block += [nn.Dropout(0.5)]
132
+
133
+ p = 0
134
+ if padding_type == 'reflect':
135
+ conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
136
+ elif padding_type == 'replicate':
137
+ conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
138
+ elif padding_type == 'zero':
139
+ p = second_dilation * 2
140
+ else:
141
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
142
+ conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
143
+ norm_layer(dim)]
144
+
145
+ return nn.Sequential(*conv_block)
146
+
147
+ def forward(self, x):
148
+ x_before = x
149
+ if self.in_dim is not None:
150
+ x = self.input_conv(x)
151
+ out = x + self.conv_block(x_before)
152
+ return out
153
+
154
+
155
+ class MultidilatedResnetBlock(nn.Module):
156
+ def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
157
+ super().__init__()
158
+ self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
159
+
160
+ def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
161
+ conv_block = []
162
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
163
+ norm_layer(dim),
164
+ activation]
165
+ if use_dropout:
166
+ conv_block += [nn.Dropout(0.5)]
167
+
168
+ conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
169
+ norm_layer(dim)]
170
+
171
+ return nn.Sequential(*conv_block)
172
+
173
+ def forward(self, x):
174
+ out = x + self.conv_block(x)
175
+ return out
176
+
177
+
178
+ class MultiDilatedGlobalGenerator(nn.Module):
179
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
180
+ n_blocks=3, norm_layer=nn.BatchNorm2d,
181
+ padding_type='reflect', conv_kind='default',
182
+ deconv_kind='convtranspose', activation=nn.ReLU(True),
183
+ up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
184
+ add_out_act=True, max_features=1024, multidilation_kwargs={},
185
+ ffc_positions=None, ffc_kwargs={}):
186
+ assert (n_blocks >= 0)
187
+ super().__init__()
188
+
189
+ conv_layer = get_conv_block_ctor(conv_kind)
190
+ resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
191
+ norm_layer = get_norm_layer(norm_layer)
192
+ if affine is not None:
193
+ norm_layer = partial(norm_layer, affine=affine)
194
+ up_norm_layer = get_norm_layer(up_norm_layer)
195
+ if affine is not None:
196
+ up_norm_layer = partial(up_norm_layer, affine=affine)
197
+
198
+ model = [nn.ReflectionPad2d(3),
199
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
200
+ norm_layer(ngf),
201
+ activation]
202
+
203
+ identity = Identity()
204
+ ### downsample
205
+ for i in range(n_downsampling):
206
+ mult = 2 ** i
207
+
208
+ model += [conv_layer(min(max_features, ngf * mult),
209
+ min(max_features, ngf * mult * 2),
210
+ kernel_size=3, stride=2, padding=1),
211
+ norm_layer(min(max_features, ngf * mult * 2)),
212
+ activation]
213
+
214
+ mult = 2 ** n_downsampling
215
+ feats_num_bottleneck = min(max_features, ngf * mult)
216
+
217
+ ### resnet blocks
218
+ for i in range(n_blocks):
219
+ if ffc_positions is not None and i in ffc_positions:
220
+ model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
221
+ inline=True, **ffc_kwargs)]
222
+ model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
223
+ conv_layer=resnet_conv_layer, activation=activation,
224
+ norm_layer=norm_layer)]
225
+
226
+ ### upsample
227
+ for i in range(n_downsampling):
228
+ mult = 2 ** (n_downsampling - i)
229
+ model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
230
+ model += [nn.ReflectionPad2d(3),
231
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
232
+ if add_out_act:
233
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
234
+ self.model = nn.Sequential(*model)
235
+
236
+ def forward(self, input):
237
+ return self.model(input)
238
+
239
+ class ConfigGlobalGenerator(nn.Module):
240
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
241
+ n_blocks=3, norm_layer=nn.BatchNorm2d,
242
+ padding_type='reflect', conv_kind='default',
243
+ deconv_kind='convtranspose', activation=nn.ReLU(True),
244
+ up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
245
+ add_out_act=True, max_features=1024,
246
+ manual_block_spec=[],
247
+ resnet_block_kind='multidilatedresnetblock',
248
+ resnet_conv_kind='multidilated',
249
+ resnet_dilation=1,
250
+ multidilation_kwargs={}):
251
+ assert (n_blocks >= 0)
252
+ super().__init__()
253
+
254
+ conv_layer = get_conv_block_ctor(conv_kind)
255
+ resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
256
+ norm_layer = get_norm_layer(norm_layer)
257
+ if affine is not None:
258
+ norm_layer = partial(norm_layer, affine=affine)
259
+ up_norm_layer = get_norm_layer(up_norm_layer)
260
+ if affine is not None:
261
+ up_norm_layer = partial(up_norm_layer, affine=affine)
262
+
263
+ model = [nn.ReflectionPad2d(3),
264
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
265
+ norm_layer(ngf),
266
+ activation]
267
+
268
+ identity = Identity()
269
+
270
+ ### downsample
271
+ for i in range(n_downsampling):
272
+ mult = 2 ** i
273
+ model += [conv_layer(min(max_features, ngf * mult),
274
+ min(max_features, ngf * mult * 2),
275
+ kernel_size=3, stride=2, padding=1),
276
+ norm_layer(min(max_features, ngf * mult * 2)),
277
+ activation]
278
+
279
+ mult = 2 ** n_downsampling
280
+ feats_num_bottleneck = min(max_features, ngf * mult)
281
+
282
+ if len(manual_block_spec) == 0:
283
+ manual_block_spec = [
284
+ DotDict(lambda : None, {
285
+ 'n_blocks': n_blocks,
286
+ 'use_default': True})
287
+ ]
288
+
289
+ ### resnet blocks
290
+ for block_spec in manual_block_spec:
291
+ def make_and_add_blocks(model, block_spec):
292
+ block_spec = DotDict(lambda : None, block_spec)
293
+ if not block_spec.use_default:
294
+ resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
295
+ resnet_conv_kind = block_spec.resnet_conv_kind
296
+ resnet_block_kind = block_spec.resnet_block_kind
297
+ if block_spec.resnet_dilation is not None:
298
+ resnet_dilation = block_spec.resnet_dilation
299
+ for i in range(block_spec.n_blocks):
300
+ if resnet_block_kind == "multidilatedresnetblock":
301
+ model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
302
+ conv_layer=resnet_conv_layer, activation=activation,
303
+ norm_layer=norm_layer)]
304
+ if resnet_block_kind == "resnetblock":
305
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
306
+ conv_kind=resnet_conv_kind)]
307
+ if resnet_block_kind == "resnetblock5x5":
308
+ model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
309
+ conv_kind=resnet_conv_kind)]
310
+ if resnet_block_kind == "resnetblockdwdil":
311
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
312
+ conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
313
+ make_and_add_blocks(model, block_spec)
314
+
315
+ ### upsample
316
+ for i in range(n_downsampling):
317
+ mult = 2 ** (n_downsampling - i)
318
+ model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
319
+ model += [nn.ReflectionPad2d(3),
320
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
321
+ if add_out_act:
322
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
323
+ self.model = nn.Sequential(*model)
324
+
325
+ def forward(self, input):
326
+ return self.model(input)
327
+
328
+
329
+ def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
330
+ blocks = []
331
+ for i in range(dilated_blocks_n):
332
+ if dilation_block_kind == 'simple':
333
+ blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
334
+ elif dilation_block_kind == 'multi':
335
+ blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
336
+ else:
337
+ raise ValueError(f'Unexpected dilation_block_kind "{dilation_block_kind}"')
338
+ return blocks
339
+
340
+
341
+ class GlobalGenerator(nn.Module):
342
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
343
+ padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
344
+ up_norm_layer=nn.BatchNorm2d, affine=None,
345
+ up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
346
+ dilated_blocks_n_middle=0,
347
+ add_out_act=True,
348
+ max_features=1024, is_resblock_depthwise=False,
349
+ ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
350
+ dilation_block_kind='simple', multidilation_kwargs={}):
351
+ assert (n_blocks >= 0)
352
+ super().__init__()
353
+
354
+ conv_layer = get_conv_block_ctor(conv_kind)
355
+ norm_layer = get_norm_layer(norm_layer)
356
+ if affine is not None:
357
+ norm_layer = partial(norm_layer, affine=affine)
358
+ up_norm_layer = get_norm_layer(up_norm_layer)
359
+ if affine is not None:
360
+ up_norm_layer = partial(up_norm_layer, affine=affine)
361
+
362
+ if ffc_positions is not None:
363
+ ffc_positions = collections.Counter(ffc_positions)
364
+
365
+ model = [nn.ReflectionPad2d(3),
366
+ conv_layer(input_nc, ngf, kernel_size=7, padding=0),
367
+ norm_layer(ngf),
368
+ activation]
369
+
370
+ identity = Identity()
371
+ ### downsample
372
+ for i in range(n_downsampling):
373
+ mult = 2 ** i
374
+
375
+ model += [conv_layer(min(max_features, ngf * mult),
376
+ min(max_features, ngf * mult * 2),
377
+ kernel_size=3, stride=2, padding=1),
378
+ norm_layer(min(max_features, ngf * mult * 2)),
379
+ activation]
380
+
381
+ mult = 2 ** n_downsampling
382
+ feats_num_bottleneck = min(max_features, ngf * mult)
383
+
384
+ dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
385
+ activation=activation, norm_layer=norm_layer)
386
+ if dilation_block_kind == 'simple':
387
+ dilated_block_kwargs['conv_kind'] = conv_kind
388
+ elif dilation_block_kind == 'multi':
389
+ dilated_block_kwargs['conv_layer'] = functools.partial(
390
+ get_conv_block_ctor('multidilated'), **multidilation_kwargs)
391
+
392
+ # dilated blocks at the start of the bottleneck sausage
393
+ if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
394
+ model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
395
+
396
+ # resnet blocks
397
+ for i in range(n_blocks):
398
+ # dilated blocks at the middle of the bottleneck sausage
399
+ if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
400
+ model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
401
+
402
+ if ffc_positions is not None and i in ffc_positions:
403
+ for _ in range(ffc_positions[i]): # same position can occur more than once
404
+ model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
405
+ inline=True, **ffc_kwargs)]
406
+
407
+ if is_resblock_depthwise:
408
+ resblock_groups = feats_num_bottleneck
409
+ else:
410
+ resblock_groups = 1
411
+
412
+ model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
413
+ norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
414
+ dilation=dilation, second_dilation=second_dilation)]
415
+
416
+
417
+ # dilated blocks at the end of the bottleneck sausage
418
+ if dilated_blocks_n is not None and dilated_blocks_n > 0:
419
+ model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
420
+
421
+ # upsample
422
+ for i in range(n_downsampling):
423
+ mult = 2 ** (n_downsampling - i)
424
+ model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
425
+ min(max_features, int(ngf * mult / 2)),
426
+ kernel_size=3, stride=2, padding=1, output_padding=1),
427
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
428
+ up_activation]
429
+ model += [nn.ReflectionPad2d(3),
430
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
431
+ if add_out_act:
432
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
433
+ self.model = nn.Sequential(*model)
434
+
435
+ def forward(self, input):
436
+ return self.model(input)
437
+
438
+
439
+ class GlobalGeneratorGated(GlobalGenerator):
440
+ def __init__(self, *args, **kwargs):
441
+ real_kwargs=dict(
442
+ conv_kind='gated_bn_relu',
443
+ activation=nn.Identity(),
444
+ norm_layer=nn.Identity
445
+ )
446
+ real_kwargs.update(kwargs)
447
+ super().__init__(*args, **real_kwargs)
448
+
449
+
450
+ class GlobalGeneratorFromSuperChannels(nn.Module):
451
+ def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
452
+ super().__init__()
453
+ self.n_downsampling = n_downsampling
454
+ norm_layer = get_norm_layer(norm_layer)
455
+ if type(norm_layer) == functools.partial:
456
+ use_bias = (norm_layer.func == nn.InstanceNorm2d)
457
+ else:
458
+ use_bias = (norm_layer == nn.InstanceNorm2d)
459
+
460
+ channels = self.convert_super_channels(super_channels)
461
+ self.channels = channels
462
+
463
+ model = [nn.ReflectionPad2d(3),
464
+ nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
465
+ norm_layer(channels[0]),
466
+ nn.ReLU(True)]
467
+
468
+ for i in range(n_downsampling): # add downsampling layers
469
+ mult = 2 ** i
470
+ model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
471
+ norm_layer(channels[1+i]),
472
+ nn.ReLU(True)]
473
+
474
+ mult = 2 ** n_downsampling
475
+
476
+ n_blocks1 = n_blocks // 3
477
+ n_blocks2 = n_blocks1
478
+ n_blocks3 = n_blocks - n_blocks1 - n_blocks2
479
+
480
+ for i in range(n_blocks1):
481
+ c = n_downsampling
482
+ dim = channels[c]
483
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
484
+
485
+ for i in range(n_blocks2):
486
+ c = n_downsampling+1
487
+ dim = channels[c]
488
+ kwargs = {}
489
+ if i == 0:
490
+ kwargs = {"in_dim": channels[c-1]}
491
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
492
+
493
+ for i in range(n_blocks3):
494
+ c = n_downsampling+2
495
+ dim = channels[c]
496
+ kwargs = {}
497
+ if i == 0:
498
+ kwargs = {"in_dim": channels[c-1]}
499
+ model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
500
+
501
+ for i in range(n_downsampling): # add upsampling layers
502
+ mult = 2 ** (n_downsampling - i)
503
+ model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
504
+ channels[n_downsampling+3+i+1],
505
+ kernel_size=3, stride=2,
506
+ padding=1, output_padding=1,
507
+ bias=use_bias),
508
+ norm_layer(channels[n_downsampling+3+i+1]),
509
+ nn.ReLU(True)]
510
+ model += [nn.ReflectionPad2d(3)]
511
+ model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
512
+
513
+ if add_out_act:
514
+ model.append(get_activation('tanh' if add_out_act is True else add_out_act))
515
+ self.model = nn.Sequential(*model)
516
+
517
+ def convert_super_channels(self, super_channels):
518
+ n_downsampling = self.n_downsampling
519
+ result = []
520
+ cnt = 0
521
+
522
+ if n_downsampling == 2:
523
+ N1 = 10
524
+ elif n_downsampling == 3:
525
+ N1 = 13
526
+ else:
527
+ raise NotImplementedError
528
+
529
+ for i in range(0, N1):
530
+ if i in [1,4,7,10]:
531
+ channel = super_channels[cnt] * (2 ** cnt)
532
+ config = {'channel': channel}
533
+ result.append(channel)
534
+ logging.info(f"Downsample channels {result[-1]}")
535
+ cnt += 1
536
+
537
+ for i in range(3):
538
+ for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
539
+ if len(super_channels) == 6:
540
+ channel = super_channels[3] * 4
541
+ else:
542
+ channel = super_channels[i + 3] * 4
543
+ config = {'channel': channel}
544
+ if counter == 0:
545
+ result.append(channel)
546
+ logging.info(f"Bottleneck channels {result[-1]}")
547
+ cnt = 2
548
+
549
+ for i in range(N1+9, N1+21):
550
+ if i in [22, 25,28]:
551
+ cnt -= 1
552
+ if len(super_channels) == 6:
553
+ channel = super_channels[5 - cnt] * (2 ** cnt)
554
+ else:
555
+ channel = super_channels[7 - cnt] * (2 ** cnt)
556
+ result.append(int(channel))
557
+ logging.info(f"Upsample channels {result[-1]}")
558
+ return result
559
+
560
+ def forward(self, input):
561
+ return self.model(input)
562
+
563
+
564
+ # Defines the PatchGAN discriminator with the specified arguments.
565
+ class NLayerDiscriminator(BaseDiscriminator):
566
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
567
+ super().__init__()
568
+ self.n_layers = n_layers
569
+
570
+ kw = 4
571
+ padw = int(np.ceil((kw-1.0)/2))
572
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
573
+ nn.LeakyReLU(0.2, True)]]
574
+
575
+ nf = ndf
576
+ for n in range(1, n_layers):
577
+ nf_prev = nf
578
+ nf = min(nf * 2, 512)
579
+
580
+ cur_model = []
581
+ cur_model += [
582
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
583
+ norm_layer(nf),
584
+ nn.LeakyReLU(0.2, True)
585
+ ]
586
+ sequence.append(cur_model)
587
+
588
+ nf_prev = nf
589
+ nf = min(nf * 2, 512)
590
+
591
+ cur_model = []
592
+ cur_model += [
593
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
594
+ norm_layer(nf),
595
+ nn.LeakyReLU(0.2, True)
596
+ ]
597
+ sequence.append(cur_model)
598
+
599
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
600
+
601
+ for n in range(len(sequence)):
602
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
603
+
604
+ def get_all_activations(self, x):
605
+ res = [x]
606
+ for n in range(self.n_layers + 2):
607
+ model = getattr(self, 'model' + str(n))
608
+ res.append(model(res[-1]))
609
+ return res[1:]
610
+
611
+ def forward(self, x):
612
+ act = self.get_all_activations(x)
613
+ return act[-1], act[:-1]
614
+
615
+
616
+ class MultidilatedNLayerDiscriminator(BaseDiscriminator):
617
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
618
+ super().__init__()
619
+ self.n_layers = n_layers
620
+
621
+ kw = 4
622
+ padw = int(np.ceil((kw-1.0)/2))
623
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
624
+ nn.LeakyReLU(0.2, True)]]
625
+
626
+ nf = ndf
627
+ for n in range(1, n_layers):
628
+ nf_prev = nf
629
+ nf = min(nf * 2, 512)
630
+
631
+ cur_model = []
632
+ cur_model += [
633
+ MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
634
+ norm_layer(nf),
635
+ nn.LeakyReLU(0.2, True)
636
+ ]
637
+ sequence.append(cur_model)
638
+
639
+ nf_prev = nf
640
+ nf = min(nf * 2, 512)
641
+
642
+ cur_model = []
643
+ cur_model += [
644
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
645
+ norm_layer(nf),
646
+ nn.LeakyReLU(0.2, True)
647
+ ]
648
+ sequence.append(cur_model)
649
+
650
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
651
+
652
+ for n in range(len(sequence)):
653
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
654
+
655
+ def get_all_activations(self, x):
656
+ res = [x]
657
+ for n in range(self.n_layers + 2):
658
+ model = getattr(self, 'model' + str(n))
659
+ res.append(model(res[-1]))
660
+ return res[1:]
661
+
662
+ def forward(self, x):
663
+ act = self.get_all_activations(x)
664
+ return act[-1], act[:-1]
665
+
666
+
667
+ class NLayerDiscriminatorAsGen(NLayerDiscriminator):
668
+ def forward(self, x):
669
+ return super().forward(x)[0]
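
A minimal usage sketch for the generator and discriminator classes above (the module path, the 4-channel masked-image input, and the tanh output behaviour are assumptions, not guaranteed by this diff):

import torch
from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, NLayerDiscriminator

# RGB image concatenated with a binary mask -> 4 input channels (assumed convention)
gen = GlobalGenerator(input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=9)
disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)

x = torch.randn(1, 4, 256, 256)
with torch.no_grad():
    fake = gen(x)              # (1, 3, 256, 256); bounded by tanh when add_out_act=True
    score, feats = disc(fake)  # patch logits plus intermediate activations for feature matching
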
inpaint/saicinpainting/training/modules/spatial_transform.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from kornia.geometry.transform import rotate
5
+
6
+
7
+ class LearnableSpatialTransformWrapper(nn.Module):
8
+ def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
9
+ super().__init__()
10
+ self.impl = impl
11
+ self.angle = torch.rand(1) * angle_init_range
12
+ if train_angle:
13
+ self.angle = nn.Parameter(self.angle, requires_grad=True)
14
+ self.pad_coef = pad_coef
15
+
16
+ def forward(self, x):
17
+ if torch.is_tensor(x):
18
+ return self.inverse_transform(self.impl(self.transform(x)), x)
19
+ elif isinstance(x, tuple):
20
+ x_trans = tuple(self.transform(elem) for elem in x)
21
+ y_trans = self.impl(x_trans)
22
+ return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
23
+ else:
24
+ raise ValueError(f'Unexpected input type {type(x)}')
25
+
26
+ def transform(self, x):
27
+ height, width = x.shape[2:]
28
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
29
+ x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
30
+ x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
31
+ return x_padded_rotated
32
+
33
+ def inverse_transform(self, y_padded_rotated, orig_x):
34
+ height, width = orig_x.shape[2:]
35
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
36
+
37
+ y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
38
+ y_height, y_width = y_padded.shape[2:]
39
+ y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
40
+ return y
41
+
42
+
43
+ if __name__ == '__main__':
44
+ layer = LearnableSpatialTransformWrapper(nn.Identity())
45
+ x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
46
+ y = layer(x)
47
+ assert x.shape == y.shape
48
+ assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
49
+ print('all ok')
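
A short usage sketch in the same spirit as the self-test above; the wrapped 3x3 convolution and the import path are illustrative assumptions:

import torch
import torch.nn as nn
from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper

# wrap any shape-preserving module so it runs on a reflect-padded, randomly rotated view
wrapped = LearnableSpatialTransformWrapper(nn.Conv2d(3, 3, kernel_size=3, padding=1))
out = wrapped(torch.randn(1, 3, 64, 64))  # same shape as the input
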
inpaint/saicinpainting/training/modules/squeeze_excitation.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class SELayer(nn.Module):
5
+ def __init__(self, channel, reduction=16):
6
+ super(SELayer, self).__init__()
7
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
8
+ self.fc = nn.Sequential(
9
+ nn.Linear(channel, channel // reduction, bias=False),
10
+ nn.ReLU(inplace=True),
11
+ nn.Linear(channel // reduction, channel, bias=False),
12
+ nn.Sigmoid()
13
+ )
14
+
15
+ def forward(self, x):
16
+ b, c, _, _ = x.size()
17
+ y = self.avg_pool(x).view(b, c)
18
+ y = self.fc(y).view(b, c, 1, 1)
19
+ res = x * y.expand_as(x)
20
+ return res
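
A hedged usage sketch for the squeeze-and-excitation layer above (import path assumed):

import torch
from saicinpainting.training.modules.squeeze_excitation import SELayer

se = SELayer(channel=64, reduction=16)
feats = torch.rand(2, 64, 32, 32)
out = se(feats)  # same shape; every channel is rescaled by a learned sigmoid gate in (0, 1)
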
inpaint/saicinpainting/training/trainers/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
3
+
4
+
5
+ def get_training_model_class(kind):
6
+ if kind == 'default':
7
+ return DefaultInpaintingTrainingModule
8
+
9
+ raise ValueError(f'Unknown trainer module {kind}')
10
+
11
+
12
+ def make_training_model(config):
13
+ kind = config.training_model.kind
14
+ kwargs = dict(config.training_model)
15
+ kwargs.pop('kind')
16
+ kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
17
+ cls = get_training_model_class(kind)
18
+ return cls(config, **kwargs)
19
+
20
+
21
+ def load_checkpoint(train_config, path, map_location='cuda', strict=True):
22
+ model: torch.nn.Module = make_training_model(train_config)
23
+ state = torch.load(path, map_location=map_location)
24
+ model.load_state_dict(state['state_dict'], strict=strict)
25
+ model.on_load_checkpoint(state)
26
+ return model
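
These helpers are normally driven from a training config; a minimal inference-side sketch (the OmegaConf call, the checkpoint directory layout, and strict=False are assumptions, not part of this repository):

import torch
from omegaconf import OmegaConf
from saicinpainting.training.trainers import load_checkpoint

train_config = OmegaConf.load('big-lama/config.yaml')  # hypothetical checkpoint directory
model = load_checkpoint(train_config, 'big-lama/models/best.ckpt',
                        map_location='cpu', strict=False)
model.eval()  # the returned object is a LightningModule wrapping the generator
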
inpaint/saicinpainting/training/trainers/base.py ADDED
@@ -0,0 +1,19 @@
1
+ import pytorch_lightning as ptl
2
+ from inpaint.saicinpainting.training.modules import make_generator
3
+
4
+
5
+
6
+
7
+ class BaseInpaintingTrainingModule(ptl.LightningModule):
8
+ def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
9
+ average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
10
+ average_generator_period=10, store_discr_outputs_for_vis=False,
11
+ **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+
14
+ self.config = config
15
+ self.generator = make_generator(config, **self.config.generator)
16
+ self.use_ddp = use_ddp
17
+ self.visualize_each_iters = visualize_each_iters
18
+
19
+
inpaint/saicinpainting/training/trainers/default.py ADDED
@@ -0,0 +1,53 @@
1
+
2
+ import torch
3
+
4
+ from saicinpainting.training.trainers.base import BaseInpaintingTrainingModule
5
+
6
+
7
+
8
+
9
+ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
10
+ def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
11
+ add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
12
+ distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
13
+ fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
14
+ **kwargs):
15
+ super().__init__(*args, **kwargs)
16
+ self.concat_mask = concat_mask
17
+ self.image_to_discriminator = image_to_discriminator
18
+ self.add_noise_kwargs = add_noise_kwargs
19
+ self.noise_fill_hole = noise_fill_hole
20
+ self.const_area_crop_kwargs = const_area_crop_kwargs
21
+ # print(distance_weighter_kwargs)
22
+ self.refine_mask_for_losses = None
23
+ self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
24
+
25
+ self.fake_fakes_proba = fake_fakes_proba
26
+
27
+ def forward(self, batch):
28
+
29
+ img = batch['image']
30
+ mask = batch['mask']
31
+
32
+ masked_img = img * (1 - mask)
33
+ if self.concat_mask:
34
+ masked_img = torch.cat([masked_img, mask], dim=1)
35
+
36
+ batch['predicted_image'] = self.generator(masked_img)
37
+ batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
38
+ if self.fake_fakes_proba > 1e-3:
39
+ if self.training and torch.rand(1).item() < self.fake_fakes_proba:
40
+ batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
41
+ batch['use_fake_fakes'] = True
42
+ else:
43
+ batch['fake_fakes'] = torch.zeros_like(img)
44
+ batch['fake_fakes_masks'] = torch.zeros_like(mask)
45
+ batch['use_fake_fakes'] = False
46
+
47
+ batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
48
+ if self.refine_mask_for_losses is not None and self.training \
49
+ else mask
50
+
51
+ return batch
52
+
53
+
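
The forward pass above only needs an image/mask pair; a sketch of the expected batch layout (key names follow the code above, shapes and the [0, 1] value range are assumptions):

import torch

batch = {
    'image': torch.rand(1, 3, 256, 256),                 # assumed to be in [0, 1]
    'mask': (torch.rand(1, 1, 256, 256) > 0.8).float(),  # 1 marks the hole to fill
}
with torch.no_grad():
    out = module(batch)    # `module` as returned by load_checkpoint(...) above
result = out['inpainted']  # original pixels outside the mask, generated pixels inside
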
sod/PGNet.py ADDED
@@ -0,0 +1,270 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .Res import resnet18
6
+ from .Swin import Swintransformer
7
+ Act = nn.ReLU
8
+
9
+
10
+ def weight_init(module):
11
+ for n, m in module.named_children():
12
+ if isinstance(m, nn.Conv2d):
13
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
14
+ if m.bias is not None:
15
+ nn.init.zeros_(m.bias)
16
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d,nn.BatchNorm1d)):
17
+ nn.init.ones_(m.weight)
18
+ if m.bias is not None:
19
+ nn.init.zeros_(m.bias)
20
+ elif isinstance(m, nn.Linear):
21
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
22
+ if m.bias is not None:
23
+ nn.init.zeros_(m.bias)
24
+ elif isinstance(m, nn.Sequential):
25
+ weight_init(m)
26
+ elif isinstance(m, nn.LayerNorm):
27
+ nn.init.constant_(m.bias, 0)
28
+ nn.init.constant_(m.weight, 1.0)
29
+ elif isinstance(m, (nn.ReLU,Act,nn.AdaptiveAvgPool2d,nn.Softmax)):
30
+ pass
31
+ else:
32
+ m.initialize()
33
+
34
+
35
+ class Grafting(nn.Module):
36
+ def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):
37
+ super().__init__()
38
+ self.num_heads = num_heads
39
+ head_dim = dim // num_heads
40
+ self.scale = qk_scale or head_dim ** -0.5
41
+ self.k = nn.Linear(dim, dim , bias=qkv_bias)
42
+ self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)
43
+ self.proj = nn.Linear(dim, dim)
44
+ self.act = nn.ReLU(inplace=True)
45
+ self.conv = nn.Conv2d(8,8,kernel_size=3, stride=1, padding=1)
46
+ self.lnx = nn.LayerNorm(64)
47
+ self.lny = nn.LayerNorm(64)
48
+ self.bn = nn.BatchNorm2d(8)
49
+ self.conv2 = nn.Sequential(
50
+ nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),
51
+ nn.BatchNorm2d(64),
52
+ nn.ReLU(inplace=True),
53
+ nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),
54
+ nn.BatchNorm2d(64),
55
+ nn.ReLU(inplace=True)
56
+ )
57
+ def forward(self, x, y):
58
+ batch_size = x.shape[0]
59
+ chanel = x.shape[1]
60
+ sc = x
61
+ x = x.view(batch_size, chanel, -1).permute(0, 2, 1)
62
+ sc1 = x
63
+ x = self.lnx(x)
64
+ y = y.view(batch_size, chanel, -1).permute(0, 2, 1)
65
+ y = self.lny(y)
66
+
67
+ B, N, C = x.shape
68
+ y_k = self.k(y).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
+ x_qv= self.qv(x).reshape(B,N,2,self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
70
+ x_q, x_v = x_qv[0], x_qv[1]
71
+ y_k = y_k[0]
72
+ attn = (x_q @ y_k.transpose(-2, -1)) * self.scale
73
+ attn = attn.softmax(dim=-1)
74
+ x = (attn @ x_v).transpose(1, 2).reshape(B, N, C)
75
+
76
+ x = self.proj(x)
77
+ x = (x+sc1)
78
+
79
+ x = x.permute(0,2,1)
80
+ x = x.view(batch_size,chanel,*sc.size()[2:])
81
+ x = self.conv2(x)+x
82
+ return x,self.act(self.bn(self.conv(attn+attn.transpose(-1,-2))))
83
+
84
+
85
+ def initialize(self):
86
+ weight_init(self)
87
+
88
+ class DB1(nn.Module):
89
+ def __init__(self,inplanes,outplanes):
90
+ super(DB1,self).__init__()
91
+ self.squeeze1 = nn.Sequential(
92
+ nn.Conv2d(inplanes, outplanes,kernel_size=1,stride=1,padding=0),
93
+ nn.BatchNorm2d(64),
94
+ nn.ReLU(inplace=True)
95
+ )
96
+ self.squeeze2 = nn.Sequential(
97
+ nn.Conv2d(64, 64, kernel_size=3,stride=1,dilation=2,padding=2),
98
+ nn.BatchNorm2d(64),
99
+ nn.ReLU(inplace=True)
100
+ )
101
+
102
+ def forward(self, x):
103
+ z = self.squeeze2(self.squeeze1(x))
104
+ return z,z
105
+
106
+ def initialize(self):
107
+ weight_init(self)
108
+
109
+ class DB2(nn.Module):
110
+ def __init__(self,inplanes,outplanes):
111
+ super(DB2,self).__init__()
112
+ self.short_cut = nn.Conv2d(outplanes, outplanes, kernel_size=1, stride=1, padding=0)
113
+ self.conv = nn.Sequential(
114
+ nn.Conv2d(inplanes+outplanes,outplanes,kernel_size=3, stride=1, padding=1),
115
+ nn.BatchNorm2d(outplanes),
116
+ nn.ReLU(inplace=True),
117
+ nn.Conv2d(outplanes,outplanes,kernel_size=3, stride=1, padding=1),
118
+ nn.BatchNorm2d(outplanes),
119
+ nn.ReLU(inplace=True)
120
+ )
121
+ self.conv2 = nn.Sequential(
122
+ nn.Conv2d(outplanes,outplanes,kernel_size=3, stride=1, padding=1),
123
+ nn.BatchNorm2d(outplanes),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(outplanes,outplanes,kernel_size=3, stride=1, padding=1),
126
+ nn.BatchNorm2d(outplanes),
127
+ nn.ReLU(inplace=True)
128
+ )
129
+
130
+ def forward(self,x,z):
131
+ z = F.interpolate(z,size=x.size()[2:],mode='bilinear',align_corners=True)
132
+ p = self.conv(torch.cat((x,z),1))
133
+ sc = self.short_cut(z)
134
+ p = p+sc
135
+ p2 = self.conv2(p)
136
+ p = p+p2
137
+ return p,p
138
+
139
+ def initialize(self):
140
+ weight_init(self)
141
+
142
+ class DB3(nn.Module):
143
+ def __init__(self) -> None:
144
+ super(DB3,self).__init__()
145
+
146
+ self.db2 = DB2(64,64)
147
+
148
+ self.conv3x3 = nn.Sequential(
149
+ nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),
150
+ nn.BatchNorm2d(64),
151
+ nn.ReLU(inplace=True)
152
+ )
153
+ self.sqz_r4 = nn.Sequential(
154
+ nn.Conv2d(256, 64, kernel_size=3,stride=1,dilation=1,padding=1),
155
+ nn.BatchNorm2d(64),
156
+ nn.ReLU(inplace=True)
157
+ )
158
+
159
+ self.sqz_s1=nn.Sequential(
160
+ nn.Conv2d(128, 64, kernel_size=3,stride=1,dilation=1,padding=1),
161
+ nn.BatchNorm2d(64),
162
+ nn.ReLU(inplace=True)
163
+ )
164
+ def forward(self,s,r,up):
165
+ up = F.interpolate(up,size=s.size()[2:],mode='bilinear',align_corners=True)
166
+ s = self.sqz_s1(s)
167
+ r = self.sqz_r4(r)
168
+ sr = self.conv3x3(s+r)
169
+ out,_ =self.db2(sr,up)
170
+ return out,out
171
+ def initialize(self):
172
+ weight_init(self)
173
+
174
+
175
+
176
+ class decoder(nn.Module):
177
+ def __init__(self) -> None:
178
+ super(decoder,self).__init__()
179
+ self.sqz_s2=nn.Sequential(
180
+ nn.Conv2d(256, 64, kernel_size=3,stride=1,dilation=1,padding=1),
181
+ nn.BatchNorm2d(64),
182
+ nn.ReLU(inplace=True)
183
+ )
184
+ self.sqz_r5 = nn.Sequential(
185
+ nn.Conv2d(512, 64, kernel_size=3,stride=1,dilation=1,padding=1),
186
+ nn.BatchNorm2d(64),
187
+ nn.ReLU(inplace=True)
188
+ )
189
+
190
+ self.GF = Grafting(64,num_heads=8)
191
+ self.d1 = DB1(512,64)
192
+ self.d2 = DB2(512,64)
193
+ self.d3 = DB2(64,64)
194
+ self.d4 = DB3()
195
+ self.d5 = DB2(128,64)
196
+ self.d6 = DB2(64,64)
197
+
198
+ def forward(self,s1,s2,s3,s4,r2,r3,r4,r5):
199
+ r5 = F.interpolate(r5,size = s2.size()[2:],mode='bilinear',align_corners=True)
200
+ s1 = F.interpolate(s1,size = r4.size()[2:],mode='bilinear',align_corners=True)
201
+
202
+ s4_,_ = self.d1(s4)
203
+ s3_,_ = self.d2(s3,s4_)
204
+
205
+ s2_ = self.sqz_s2(s2)
206
+ r5_= self.sqz_r5(r5)
207
+ graft_feature_r5,cam = self.GF(r5_,s2_)
208
+
209
+ graft_feature_r5_,_=self.d3(graft_feature_r5,s3_)
210
+
211
+ graft_feature_r4,_=self.d4(s1,r4,graft_feature_r5_)
212
+
213
+ r3_,_ = self.d5(r3,graft_feature_r4)
214
+
215
+ r2_,_ = self.d6(r2,r3_)
216
+
217
+ return r2_,cam,r5_,s2_
218
+
219
+ def initialize(self):
220
+ weight_init(self)
221
+
222
+
223
+
224
+
225
+ class PGNet(nn.Module):
226
+ def __init__(self, cfg=None):
227
+ super(PGNet, self).__init__()
228
+ self.cfg = cfg
229
+ self.decoder = decoder()
230
+ self.linear1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
231
+ self.linear2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
232
+ self.linear3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
233
+ self.conv = nn.Conv2d(8,1,kernel_size=3, stride=1, padding=1)
234
+
235
+
236
+ if self.cfg is None or self.cfg.snapshot is None:
237
+ weight_init(self)
238
+
239
+ self.resnet = resnet18()
240
+ self.swin = Swintransformer(224)
241
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
242
+ self.swin.load_state_dict(torch.load('sod/weights/swin224.pth', map_location=device)['model'],strict=False)
243
+ self.resnet.load_state_dict(torch.load('sod/weights/resnet18.pth', map_location=device),strict=False)
244
+
245
+ if self.cfg is not None and self.cfg.snapshot:
246
+ print('load checkpoint')
247
+ pretrain=torch.load(self.cfg.snapshot, map_location=device)
248
+ new_state_dict = {}
249
+ for k,v in pretrain.items():
250
+ new_state_dict[k[7:]] = v
251
+ self.load_state_dict(new_state_dict, strict=False)
252
+
253
+ def forward(self, x,shape=None,mask=None):
254
+ shape = x.size()[2:] if shape is None else shape
255
+ y = F.interpolate(x, size=(224,224), mode='bilinear',align_corners=True)
256
+
257
+ r2,r3,r4,r5 = self.resnet(x)
258
+ s1,s2,s3,s4 = self.swin(y)
259
+ r2_,attmap,r5_,s2_ = self.decoder(s1,s2,s3,s4,r2,r3,r4,r5)
260
+
261
+ pred1 = F.interpolate(self.linear1(r2_), size=shape, mode='bilinear')
262
+ wr = F.interpolate(self.linear2(r5_), size=(28,28), mode='bilinear')
263
+ ws = F.interpolate(self.linear3(s2_), size=(28,28), mode='bilinear')
264
+
265
+
266
+ return pred1,wr,ws,self.conv(attmap)
267
+
268
+
269
+
270
+
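
A usage sketch for the saliency network defined above; the import path and the 352x352 input are assumptions, and the constructor expects the weight files under sod/weights/ to exist:

import torch
from sod.PGNet import PGNet

net = PGNet().eval()              # cfg=None: backbone weights are loaded from sod/weights/
img = torch.rand(1, 3, 352, 352)  # arbitrary size; the Swin branch is resized to 224x224 internally
with torch.no_grad():
    pred, wr, ws, att = net(img)
saliency = torch.sigmoid(pred)[0, 0]  # logits -> probability map at the input resolution
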
sod/Res.py ADDED
@@ -0,0 +1,363 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7
+ 'wide_resnet50_2', 'wide_resnet101_2']
8
+
9
+
10
+ model_urls = {
11
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
17
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
18
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
19
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
20
+ }
21
+
22
+
23
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
24
+ """3x3 convolution with padding"""
25
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
27
+
28
+
29
+ def conv1x1(in_planes, out_planes, stride=1):
30
+ """1x1 convolution"""
31
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32
+
33
+
34
+ class BasicBlock(nn.Module):
35
+ expansion = 1
36
+
37
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
38
+ base_width=64, dilation=1, norm_layer=None):
39
+ super(BasicBlock, self).__init__()
40
+ if norm_layer is None:
41
+ norm_layer = nn.BatchNorm2d
42
+ if groups != 1 or base_width != 64:
43
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
44
+ if dilation > 1:
45
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
46
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
47
+ self.conv1 = conv3x3(inplanes, planes, stride)
48
+ self.bn1 = norm_layer(planes)
49
+ self.relu = nn.ReLU(inplace=True)
50
+ self.conv2 = conv3x3(planes, planes)
51
+ self.bn2 = norm_layer(planes)
52
+ self.downsample = downsample
53
+ self.stride = stride
54
+
55
+ def forward(self, x):
56
+ identity = x
57
+
58
+ out = self.conv1(x)
59
+ out = self.bn1(out)
60
+ out = self.relu(out)
61
+
62
+ out = self.conv2(out)
63
+ out = self.bn2(out)
64
+
65
+ if self.downsample is not None:
66
+ identity = self.downsample(x)
67
+
68
+ out += identity
69
+ out = self.relu(out)
70
+
71
+ return out
72
+ def initialize(self):
73
+ weight_init(self)
74
+
75
+
76
+ class Bottleneck(nn.Module):
77
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
78
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
79
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
80
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
81
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
82
+
83
+ expansion = 4
84
+
85
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
86
+ base_width=64, dilation=1, norm_layer=None):
87
+ super(Bottleneck, self).__init__()
88
+ if norm_layer is None:
89
+ norm_layer = nn.BatchNorm2d
90
+ width = int(planes * (base_width / 64.)) * groups
91
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
92
+ self.conv1 = conv1x1(inplanes, width)
93
+ self.bn1 = norm_layer(width)
94
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
95
+ self.bn2 = norm_layer(width)
96
+ self.conv3 = conv1x1(width, planes * self.expansion)
97
+ self.bn3 = norm_layer(planes * self.expansion)
98
+ self.relu = nn.ReLU(inplace=True)
99
+ self.downsample = downsample
100
+ self.stride = stride
101
+
102
+ def forward(self, x):
103
+ identity = x
104
+
105
+ out = self.conv1(x)
106
+ out = self.bn1(out)
107
+ out = self.relu(out)
108
+
109
+ out = self.conv2(out)
110
+ out = self.bn2(out)
111
+ out = self.relu(out)
112
+
113
+ out = self.conv3(out)
114
+ out = self.bn3(out)
115
+
116
+ if self.downsample is not None:
117
+ identity = self.downsample(x)
118
+
119
+ out += identity
120
+ out = self.relu(out)
121
+
122
+ return out
123
+ def initialize(self):
124
+ weight_init(self)
125
+
126
+ def weight_init(module):
127
+ for n, m in module.named_children():
128
+ # print('initialize: '+n)
129
+ if isinstance(m, nn.Conv2d):
130
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
131
+ if m.bias is not None:
132
+ nn.init.zeros_(m.bias)
133
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
134
+ nn.init.ones_(m.weight)
135
+ if m.bias is not None:
136
+ nn.init.zeros_(m.bias)
137
+ elif isinstance(m, nn.Linear):
138
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
139
+ if m.bias is not None:
140
+ nn.init.zeros_(m.bias)
141
+ elif isinstance(m, nn.Sequential):
142
+ weight_init(m)
143
+ elif isinstance(m, (nn.ReLU,nn.AdaptiveAvgPool2d,nn.Softmax,nn.MaxPool2d)):
144
+ pass
145
+ else:
146
+ m.initialize()
147
+
148
+
149
+ class ResNet(nn.Module):
150
+
151
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
152
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
153
+ norm_layer=None):
154
+ super(ResNet, self).__init__()
155
+ if norm_layer is None:
156
+ norm_layer = nn.BatchNorm2d
157
+ self._norm_layer = norm_layer
158
+
159
+ self.inplanes = 64
160
+ self.dilation = 1
161
+ if replace_stride_with_dilation is None:
162
+ # each element in the tuple indicates if we should replace
163
+ # the 2x2 stride with a dilated convolution instead
164
+ replace_stride_with_dilation = [False, False, False]
165
+ if len(replace_stride_with_dilation) != 3:
166
+ raise ValueError("replace_stride_with_dilation should be None "
167
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
168
+ self.groups = groups
169
+ self.base_width = width_per_group
170
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
171
+ bias=False)
172
+ self.bn1 = norm_layer(self.inplanes)
173
+ self.relu = nn.ReLU(inplace=True)
174
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
175
+ self.layer1 = self._make_layer(block, 64, layers[0])
176
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
177
+ dilate=replace_stride_with_dilation[0])
178
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
179
+ dilate=replace_stride_with_dilation[1])
180
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
181
+ dilate=replace_stride_with_dilation[2])
182
+
183
+ for m in self.modules():
184
+ if isinstance(m, nn.Conv2d):
185
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
186
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
187
+ nn.init.constant_(m.weight, 1)
188
+ nn.init.constant_(m.bias, 0)
189
+
190
+ # Zero-initialize the last BN in each residual branch,
191
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
192
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
193
+ if zero_init_residual:
194
+ for m in self.modules():
195
+ if isinstance(m, Bottleneck):
196
+ nn.init.constant_(m.bn3.weight, 0)
197
+ elif isinstance(m, BasicBlock):
198
+ nn.init.constant_(m.bn2.weight, 0)
199
+
200
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
201
+ norm_layer = self._norm_layer
202
+ downsample = None
203
+ previous_dilation = self.dilation
204
+ if dilate:
205
+ self.dilation *= stride
206
+ stride = 1
207
+ if stride != 1 or self.inplanes != planes * block.expansion:
208
+ downsample = nn.Sequential(
209
+ conv1x1(self.inplanes, planes * block.expansion, stride),
210
+ norm_layer(planes * block.expansion),
211
+ )
212
+
213
+ layers = []
214
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
215
+ self.base_width, previous_dilation, norm_layer))
216
+ self.inplanes = planes * block.expansion
217
+ for _ in range(1, blocks):
218
+ layers.append(block(self.inplanes, planes, groups=self.groups,
219
+ base_width=self.base_width, dilation=self.dilation,
220
+ norm_layer=norm_layer))
221
+
222
+ return nn.Sequential(*layers)
223
+
224
+ def forward(self, x):
225
+ out1 = F.relu(self.bn1(self.conv1(x)),inplace=True)
226
+ out1 = self.maxpool(out1)
227
+ out2 = self.layer1(out1)
228
+ out3 = self.layer2(out2)
229
+ out4 = self.layer3(out3)
230
+ out5 = self.layer4(out4)
231
+ return out2, out3, out4, out5
232
+ def initialize(self):
233
+ weight_init(self)
234
+
235
+
236
+ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
237
+ model = ResNet(block, layers, **kwargs)
238
+
239
+ return model
240
+
241
+
242
+ def resnet18(pretrained=False, progress=True, **kwargs):
243
+ r"""ResNet-18 model from
244
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
245
+
246
+ Args:
247
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
248
+ progress (bool): If True, displays a progress bar of the download to stderr
249
+ """
250
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
251
+ **kwargs)
252
+
253
+
254
+ def resnet34(pretrained=False, progress=True, **kwargs):
255
+ r"""ResNet-34 model from
256
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
257
+
258
+ Args:
259
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
260
+ progress (bool): If True, displays a progress bar of the download to stderr
261
+ """
262
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
263
+ **kwargs)
264
+
265
+
266
+ def resnet50(pretrained=False, progress=True, **kwargs):
267
+ r"""ResNet-50 model from
268
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
269
+
270
+ Args:
271
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
272
+ progress (bool): If True, displays a progress bar of the download to stderr
273
+ """
274
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
275
+ **kwargs)
276
+
277
+
278
+ def resnet101(pretrained=False, progress=True, **kwargs):
279
+ r"""ResNet-101 model from
280
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
281
+
282
+ Args:
283
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
284
+ progress (bool): If True, displays a progress bar of the download to stderr
285
+ """
286
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
287
+ **kwargs)
288
+
289
+
290
+ def resnet152(pretrained=False, progress=True, **kwargs):
291
+ r"""ResNet-152 model from
292
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
293
+
294
+ Args:
295
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
296
+ progress (bool): If True, displays a progress bar of the download to stderr
297
+ """
298
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
299
+ **kwargs)
300
+
301
+
302
+ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
303
+ r"""ResNeXt-50 32x4d model from
304
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
305
+
306
+ Args:
307
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
308
+ progress (bool): If True, displays a progress bar of the download to stderr
309
+ """
310
+ kwargs['groups'] = 32
311
+ kwargs['width_per_group'] = 4
312
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
313
+ pretrained, progress, **kwargs)
314
+
315
+
316
+ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
317
+ r"""ResNeXt-101 32x8d model from
318
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
319
+
320
+ Args:
321
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
322
+ progress (bool): If True, displays a progress bar of the download to stderr
323
+ """
324
+ kwargs['groups'] = 32
325
+ kwargs['width_per_group'] = 8
326
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
327
+ pretrained, progress, **kwargs)
328
+
329
+
330
+ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
331
+ r"""Wide ResNet-50-2 model from
332
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
333
+
334
+ The model is the same as ResNet except for the bottleneck number of channels
335
+ which is twice larger in every block. The number of channels in outer 1x1
336
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
337
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
338
+
339
+ Args:
340
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
341
+ progress (bool): If True, displays a progress bar of the download to stderr
342
+ """
343
+ kwargs['width_per_group'] = 64 * 2
344
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
345
+ pretrained, progress, **kwargs)
346
+
347
+
348
+ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
349
+ r"""Wide ResNet-101-2 model from
350
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
351
+
352
+ The model is the same as ResNet except for the bottleneck number of channels
353
+ which is twice larger in every block. The number of channels in outer 1x1
354
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
355
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
356
+
357
+ Args:
358
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
359
+ progress (bool): If True, displays a progress bar of the download to stderr
360
+ """
361
+ kwargs['width_per_group'] = 64 * 2
362
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
363
+ pretrained, progress, **kwargs)
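
Unlike torchvision's ResNet, the variant above drops the classification head and returns the four stage feature maps; a small sketch (import path assumed):

import torch
from sod.Res import resnet18

backbone = resnet18()
feats = backbone(torch.rand(1, 3, 224, 224))
for f in feats:
    print(tuple(f.shape))  # (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)
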
sod/Swin.py ADDED
@@ -0,0 +1,578 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as checkpoint
4
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
+
6
+ class Mlp(nn.Module):
7
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
8
+ super().__init__()
9
+ out_features = out_features or in_features
10
+ hidden_features = hidden_features or in_features
11
+ self.fc1 = nn.Linear(in_features, hidden_features)
12
+ self.act = act_layer()
13
+ self.fc2 = nn.Linear(hidden_features, out_features)
14
+ self.drop = nn.Dropout(drop)
15
+
16
+ def forward(self, x):
17
+ x = self.fc1(x)
18
+ x = self.act(x)
19
+ x = self.drop(x)
20
+ x = self.fc2(x)
21
+ x = self.drop(x)
22
+ return x
23
+
24
+
25
+ def window_partition(x, window_size):
26
+ """
27
+ Args:
28
+ x: (B, H, W, C)
29
+ window_size (int): window size
30
+
31
+ Returns:
32
+ windows: (num_windows*B, window_size, window_size, C)
33
+ """
34
+ B, H, W, C = x.shape
35
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
36
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ return windows
38
+
39
+
40
+ def window_reverse(windows, window_size, H, W):
41
+ """
42
+ Args:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ window_size (int): Window size
45
+ H (int): Height of image
46
+ W (int): Width of image
47
+
48
+ Returns:
49
+ x: (B, H, W, C)
50
+ """
51
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
52
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
53
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
54
+ return x
55
+
56
+
57
+ class WindowAttention(nn.Module):
58
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
59
+ It supports both of shifted and non-shifted window.
60
+
61
+ Args:
62
+ dim (int): Number of input channels.
63
+ window_size (tuple[int]): The height and width of the window.
64
+ num_heads (int): Number of attention heads.
65
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
66
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
67
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
68
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
69
+ """
70
+
71
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
72
+
73
+ super().__init__()
74
+ self.dim = dim
75
+ self.window_size = window_size # Wh, Ww
76
+ self.num_heads = num_heads
77
+ head_dim = dim // num_heads
78
+ self.scale = qk_scale or head_dim ** -0.5
79
+
80
+ # define a parameter table of relative position bias
81
+ self.relative_position_bias_table = nn.Parameter(
82
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
83
+
84
+ # get pair-wise relative position index for each token inside the window
85
+ coords_h = torch.arange(self.window_size[0])
86
+ coords_w = torch.arange(self.window_size[1])
87
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
88
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
89
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
90
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
91
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
92
+ relative_coords[:, :, 1] += self.window_size[1] - 1
93
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
94
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
95
+ self.register_buffer("relative_position_index", relative_position_index)
96
+
97
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
98
+ self.attn_drop = nn.Dropout(attn_drop)
99
+ self.proj = nn.Linear(dim, dim)
100
+ self.proj_drop = nn.Dropout(proj_drop)
101
+
102
+ trunc_normal_(self.relative_position_bias_table, std=.02)
103
+ self.softmax = nn.Softmax(dim=-1)
104
+
105
+ def forward(self, x, mask=None):
106
+ """
107
+ Args:
108
+ x: input features with shape of (num_windows*B, N, C)
109
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
110
+ """
111
+ B_, N, C = x.shape
112
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
113
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
114
+
115
+ q = q * self.scale
116
+ attn = (q @ k.transpose(-2, -1))
117
+
118
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
119
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
120
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
121
+ attn = attn + relative_position_bias.unsqueeze(0)
122
+
123
+ if mask is not None:
124
+ nW = mask.shape[0]
125
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
126
+ attn = attn.view(-1, self.num_heads, N, N)
127
+ attn = self.softmax(attn)
128
+ else:
129
+ attn = self.softmax(attn)
130
+
131
+ attn = self.attn_drop(attn)
132
+
133
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
134
+ x = self.proj(x)
135
+ x = self.proj_drop(x)
136
+ return x
137
+
138
+ def extra_repr(self) -> str:
139
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
140
+
141
+ def flops(self, N):
142
+ # calculate flops for 1 window with token length of N
143
+ flops = 0
144
+ # qkv = self.qkv(x)
145
+ flops += N * self.dim * 3 * self.dim
146
+ # attn = (q @ k.transpose(-2, -1))
147
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
148
+ # x = (attn @ v)
149
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
150
+ # x = self.proj(x)
151
+ flops += N * self.dim * self.dim
152
+ return flops
153
+
154
+
155
+ class SwinTransformerBlock(nn.Module):
156
+ r""" Swin Transformer Block.
157
+
158
+ Args:
159
+ dim (int): Number of input channels.
160
+ input_resolution (tuple[int]): Input resolution.
161
+ num_heads (int): Number of attention heads.
162
+ window_size (int): Window size.
163
+ shift_size (int): Shift size for SW-MSA.
164
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
165
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
166
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
167
+ drop (float, optional): Dropout rate. Default: 0.0
168
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
169
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
170
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
171
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
172
+ """
173
+
174
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
175
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
176
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
177
+ super().__init__()
178
+ self.dim = dim
179
+ self.input_resolution = input_resolution
180
+ self.num_heads = num_heads
181
+ self.window_size = window_size
182
+ self.shift_size = shift_size
183
+ self.mlp_ratio = mlp_ratio
184
+ if min(self.input_resolution) <= self.window_size:
185
+ # if window size is larger than input resolution, we don't partition windows
186
+ self.shift_size = 0
187
+ self.window_size = min(self.input_resolution)
188
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
189
+
190
+ self.norm1 = norm_layer(dim)
191
+ self.attn = WindowAttention(
192
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
193
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
194
+
195
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
196
+ self.norm2 = norm_layer(dim)
197
+ mlp_hidden_dim = int(dim * mlp_ratio)
198
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
199
+
200
+ if self.shift_size > 0:
201
+ # calculate attention mask for SW-MSA
202
+ H, W = self.input_resolution
203
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
204
+ h_slices = (slice(0, -self.window_size),
205
+ slice(-self.window_size, -self.shift_size),
206
+ slice(-self.shift_size, None))
207
+ w_slices = (slice(0, -self.window_size),
208
+ slice(-self.window_size, -self.shift_size),
209
+ slice(-self.shift_size, None))
210
+ cnt = 0
211
+ for h in h_slices:
212
+ for w in w_slices:
213
+ img_mask[:, h, w, :] = cnt
214
+ cnt += 1
215
+
216
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
217
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
218
+ atten_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
219
+ atten_mask = atten_mask.masked_fill(atten_mask != 0, float(-100.0)).masked_fill(atten_mask == 0, float(0.0))
220
+ else:
221
+ atten_mask = None
222
+
223
+ self.register_buffer("atten_mask", atten_mask)
224
+
225
+ def forward(self, x):
226
+ H, W = self.input_resolution
227
+ B, L, C = x.shape
228
+ assert L == H * W, "input feature has wrong size"
229
+
230
+ shortcut = x
231
+ x = self.norm1(x)
232
+ x = x.view(B, H, W, C)
233
+
234
+ # cyclic shift
235
+ if self.shift_size > 0:
236
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
237
+ else:
238
+ shifted_x = x
239
+
240
+ # partition windows
241
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
242
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
243
+
244
+ # W-MSA/SW-MSA
245
+ attn_windows = self.attn(x_windows, mask=self.atten_mask) # nW*B, window_size*window_size, C
246
+
247
+ # merge windows
248
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
249
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
250
+
251
+ # reverse cyclic shift
252
+ if self.shift_size > 0:
253
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
254
+ else:
255
+ x = shifted_x
256
+ x = x.view(B, H * W, C)
257
+
258
+ # FFN
259
+ x = shortcut + self.drop_path(x)
260
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
261
+
262
+ return x
263
+
264
+ def extra_repr(self) -> str:
265
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
266
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
267
+
268
+ def flops(self):
269
+ flops = 0
270
+ H, W = self.input_resolution
271
+ # norm1
272
+ flops += self.dim * H * W
273
+ # W-MSA/SW-MSA
274
+ nW = H * W / self.window_size / self.window_size
275
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
276
+ # mlp
277
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
278
+ # norm2
279
+ flops += self.dim * H * W
280
+ return flops
281
+
282
+
283
+ class PatchMerging(nn.Module):
+     r""" Patch Merging Layer.
+
+     Args:
+         input_resolution (tuple[int]): Resolution of input feature.
+         dim (int): Number of input channels.
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+     """
+
+     def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.dim = dim
+         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+         self.norm = norm_layer(4 * dim)
+
+     def forward(self, x):
+         """
+         x: B, H*W, C
+         """
+         H, W = self.input_resolution
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+         x = x.view(B, H, W, C)
+
+         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+         x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+         x = self.norm(x)
+         x = self.reduction(x)
+
+         return x
+
+     def extra_repr(self) -> str:
+         return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+     def flops(self):
+         H, W = self.input_resolution
+         flops = H * W * self.dim
+         flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+         return flops
+
+
+ class BasicLayer(nn.Module):
+     """ A basic Swin Transformer layer for one stage.
+
+     Args:
+         dim (int): Number of input channels.
+         input_resolution (tuple[int]): Input resolution.
+         depth (int): Number of blocks.
+         num_heads (int): Number of attention heads.
+         window_size (int): Local window size.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+         qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+         drop (float, optional): Dropout rate. Default: 0.0
+         attn_drop (float, optional): Attention dropout rate. Default: 0.0
+         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+     """
+
+     def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+                                  num_heads=num_heads, window_size=window_size,
+                                  shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                  mlp_ratio=mlp_ratio,
+                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                  drop=drop, attn_drop=attn_drop,
+                                  drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                  norm_layer=norm_layer)
+             for i in range(depth)])
+
+         # patch merging layer
+         if downsample is not None:
+             self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return x
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+     def flops(self):
+         flops = 0
+         for blk in self.blocks:
+             flops += blk.flops()
+         if self.downsample is not None:
+             flops += self.downsample.flops()
+         return flops
+
+
+ class PatchEmbed(nn.Module):
+     r""" Image to Patch Embedding
+
+     Args:
+         img_size (int): Image size. Default: 224.
+         patch_size (int): Patch token size. Default: 4.
+         in_chans (int): Number of input image channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         norm_layer (nn.Module, optional): Normalization layer. Default: None
+     """
+
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = patches_resolution
+         self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         if norm_layer is not None:
+             self.norm = norm_layer(embed_dim)
+         else:
+             self.norm = None
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+         if self.norm is not None:
+             x = self.norm(x)
+         return x
+
+     def flops(self):
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
+
+
+ class Swintransformer(nn.Module):
+     r""" Swin Transformer
+         A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+         https://arxiv.org/pdf/2103.14030
+
+     Args:
+         img_size (int | tuple(int)): Input image size. Default 224
+         patch_size (int | tuple(int)): Patch size. Default: 4
+         in_chans (int): Number of input image channels. Default: 3
+         num_classes (int): Number of classes for classification head. Default: 1000
+         embed_dim (int): Patch embedding dimension. Default: 96
+         depths (tuple(int)): Depth of each Swin Transformer layer.
+         num_heads (tuple(int)): Number of attention heads in different layers.
+         window_size (int): Window size. Default: 7
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+         qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+         qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+         drop_rate (float): Dropout rate. Default: 0
+         attn_drop_rate (float): Attention dropout rate. Default: 0
+         drop_path_rate (float): Stochastic depth rate. Default: 0.1
+         norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+         ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+         patch_norm (bool): If True, add normalization after patch embedding. Default: True
+         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+     """
+
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                  embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
+                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5,
+                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                  use_checkpoint=False, **kwargs):
+         super().__init__()
+
+         self.num_classes = num_classes
+         self.num_layers = len(depths)
+         self.embed_dim = embed_dim
+         self.ape = ape
+         self.patch_norm = patch_norm
+         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+         self.mlp_ratio = mlp_ratio
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+             norm_layer=norm_layer if self.patch_norm else None)
+         num_patches = self.patch_embed.num_patches
+         patches_resolution = self.patch_embed.patches_resolution
+         self.patches_resolution = patches_resolution
+
+         # absolute position embedding
+         if self.ape:
+             self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+             trunc_normal_(self.absolute_pos_embed, std=.02)
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+         # build layers
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers - 1):
+             layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+                                input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                  patches_resolution[1] // (2 ** i_layer)),
+                                depth=depths[i_layer],
+                                num_heads=num_heads[i_layer],
+                                window_size=window_size,
+                                mlp_ratio=self.mlp_ratio,
+                                qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                drop=drop_rate, attn_drop=attn_drop_rate,
+                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                                norm_layer=norm_layer,
+                                downsample=PatchMerging if (i_layer < self.num_layers - 2) else None,
+                                use_checkpoint=use_checkpoint)
+             self.layers.append(layer)
+
+         self.norm = norm_layer(self.num_features)
+         self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'absolute_pos_embed'}
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {'relative_position_bias_table'}
+
+     def forward_features(self, x):
+         b, c, h, w = x.shape
+         x = self.patch_embed(x)
+
+         if self.ape:
+             x = x + self.absolute_pos_embed
+         x = self.pos_drop(x)
+         s = []
+         s.append(x.view(b, int(x.shape[1] ** 0.5), int(x.shape[1] ** 0.5), -1).permute(0, 3, 1, 2).contiguous())
+         for layer in self.layers:
+             x = layer(x)
+             s.append(x.view(b, int(x.shape[1] ** 0.5), int(x.shape[1] ** 0.5), -1).permute(0, 3, 1, 2).contiguous())
+
+         # x = self.norm(x)  # B L C
+         # x = self.avgpool(x.transpose(1, 2))  # B C 1
+         # x = torch.flatten(x, 1)
+         return s
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         # x = self.head(x)
+         return x
+
+     def flops(self):
+         flops = 0
+         flops += self.patch_embed.flops()
+         for i, layer in enumerate(self.layers):
+             flops += layer.flops()
+         return flops
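
A minimal usage sketch of the Swin backbone defined above (illustrative only, not part of this commit). It assumes the class is importable as sod.Swin.Swintransformer — the actual module name is not shown in this diff — and that the helpers it references (WindowAttention, Mlp, window_partition, window_reverse, DropPath, to_2tuple, trunc_normal_) are defined earlier in the same file. forward() returns the list of multi-scale feature maps collected in forward_features(), one entry for the patch embedding plus one per stage:

import torch
from sod.Swin import Swintransformer  # hypothetical import path

net = Swintransformer(img_size=224, embed_dim=128, depths=[2, 2, 18, 2],
                      num_heads=[4, 8, 16, 32], window_size=7)
net.eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 224, 224))  # list of pyramid features, fine to coarse
for f in feats:
    print(f.shape)  # (1, 128, 56, 56), (1, 256, 28, 28), (1, 512, 14, 14), (1, 512, 14, 14) for these settings

Note that the last BasicLayer is built with downsample=None, so the final two entries share the same spatial resolution, and the classification head of the reference implementation is dropped in favour of returning the feature pyramid, presumably for the saliency decoder built on top of it.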
sod/configs/prediction/default.yaml ADDED
@@ -0,0 +1,14 @@
+ indir: no  # to be overridden in CLI
+ outdir: no  # to be overridden in CLI
+
+ model:
+   path: no  # to be overridden in CLI
+   checkpoint: best.ckpt
+
+ dataset:
+   kind: default
+   img_suffix: .png
+   pad_out_to_modulo: 8
+
+ device: cuda
+ out_key: inpainted
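
This prediction config follows the LaMa-style layout, which is typically consumed through OmegaConf; how this repository actually loads it is not shown in the diff, so the snippet below is only an illustration. A bare `no` resolves to the boolean false under YAML 1.1 resolvers such as PyYAML's, which is why those fields must be overridden on the command line.

from omegaconf import OmegaConf

cfg = OmegaConf.load('sod/configs/prediction/default.yaml')
# e.g. python predict.py indir=./input outdir=./output model.path=./weights
cfg = OmegaConf.merge(cfg, OmegaConf.from_cli())
print(cfg.model.checkpoint, cfg.dataset.pad_out_to_modulo, cfg.device)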
sod/infer_model.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ import sys
+ sys.path.insert(0, '../')
+ sys.dont_write_bytecode = True
+ from .PGNet import PGNet
+
+ class Normalize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, image):
+         image = (image - self.mean) / self.std
+         return image
+
+ class Config(object):
+     def __init__(self, **kwargs):
+         self.kwargs = kwargs
+         self.mean = np.array([[[124.55, 118.90, 102.94]]])
+         self.std = np.array([[[56.77, 55.97, 57.50]]])
+         print('\nParameters...')
+         for k, v in self.kwargs.items():
+             print('%-10s: %s' % (k, v))
+
+     def __getattr__(self, name):
+         if name in self.kwargs:
+             return self.kwargs[name]
+         else:
+             return None
+
+ class IVModel():
+     def __init__(self, device=torch.device('cuda:0')):
+         super(IVModel, self).__init__()
+         self.device = device
+         checkpoint_path = 'sod/weights/PGNet_DUT+HR-model-31.pth'
+         self.cfg = Config(snapshot=checkpoint_path, mode='test')
+         if not os.path.exists(checkpoint_path):
+             print('Model weights file not found!')
+         self.net = PGNet(self.cfg)
+         self.net.train(False)
+         self.net.to(device)
+         self.normalize = Normalize(mean=self.cfg.mean, std=self.cfg.std)
+
+         self.__first_forward__()
+
+
+     def __first_forward__(self, input_size=(2048, 2048, 3)):
+         # run forward() once to strictly bound the peak GPU memory usage
+         print('initialize Sod Model...')
+         _ = self.forward(np.random.rand(*input_size) * 255, None)
+         print('initialize Complete!')
+
+     def __resize_tensor__(self, image, max_size=1024):
+         h, w = image.size()[2:]
+         if max(h, w) > max_size:
+             if h < w:
+                 h, w = int(max_size * h / w) // 8 * 8, max_size
+             else:
+                 h, w = max_size, int(max_size * w / h) // 8 * 8
+             image = F.interpolate(image, (h, w), mode='area')
+         return image
+
+     def input_preprocess_tensor(self, img):
+         img = self.normalize(img)
+         img_t = torch.from_numpy(img.astype(np.float32))  # .to(self.device)
+         img_t = img_t.permute(2, 0, 1).unsqueeze(0)
+         img_t = self.__resize_tensor__(img_t).to(self.device)  # resize before moving to the device to cap GPU memory
+         return img_t
+
+     def forward(self, img, json_data):
+         img_t = self.input_preprocess_tensor(img)
+         shape = [torch.as_tensor([img_t.shape[2]]), torch.as_tensor([img_t.shape[3]])]
+         h, w = img_t.shape[2], img_t.shape[3]
+         img_t_temp = F.interpolate(img_t, (1024, 1024), mode='area')
+         with torch.no_grad():
+             res = self.net(img_t_temp, shape=shape)
+         res = F.interpolate(res[0], size=shape, mode='bilinear')
+         res = torch.sigmoid(res)
+         # print(res.shape, img_t.shape, res.expand_as(img_t).shape)
+         res = torch.cat([img_t, res.expand_as(img_t)], dim=3)
+         res = (res[0].permute(1, 2, 0)).cpu().numpy()
+         res[:, :w, :] = res[:, :w, :] * self.cfg.std + self.cfg.mean
+         res[:, w:, :] = res[:, w:, :] * 255
+         return res
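
A minimal sketch of driving the wrapper above end to end (illustrative, not part of this commit). It assumes sod is importable as a package from the repository root and that the PGNet weights referenced in __init__ ('sod/weights/PGNet_DUT+HR-model-31.pth') are present; constructing IVModel already runs a 2048x2048 warm-up forward pass. The normalization constants suggest RGB input, so the channel flip applied to the BGR array from cv2.imread is an assumption, and 'input.jpg' is a placeholder path:

import cv2
import numpy as np
import torch
from sod.infer_model import IVModel

model = IVModel(device=torch.device('cpu'))
img = cv2.imread('input.jpg')                  # any test image, H x W x 3, BGR
out = model.forward(img[:, :, ::-1], None)     # float array: [resized input | saliency mask * 255] side by side
out = np.uint8(np.clip(out, 0, 255))
cv2.imwrite('sod_result.png', out[:, :, ::-1])  # flip back to BGR for OpenCV

The saliency prediction occupies the right half of the returned array (columns w: onward), which downstream code can threshold into a mask for the inpainting stage.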