juba7 committed on
Commit c17cd98 · verified · 1 Parent(s): ad9a05a

Upload 9 files

Files changed (9)
  1. .gitattributes +43 -35
  2. BEN_Base.pth +3 -0
  3. README.md +109 -0
  4. config.json +6 -0
  5. demo.jpg +3 -0
  6. image.png +3 -0
  7. inference.py +17 -0
  8. model.py +951 -0
  9. requirements.txt +6 -0
.gitattributes CHANGED
@@ -1,35 +1,43 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 027.jpg filter=lfs diff=lfs merge=lfs -text
+ 027.png filter=lfs diff=lfs merge=lfs -text
+ 4340.png filter=lfs diff=lfs merge=lfs -text
+ 5442.jpg filter=lfs diff=lfs merge=lfs -text
+ 5442.png filter=lfs diff=lfs merge=lfs -text
+ image.jpg filter=lfs diff=lfs merge=lfs -text
+ image.png filter=lfs diff=lfs merge=lfs -text
+ demo.jpg filter=lfs diff=lfs merge=lfs -text
BEN_Base.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f00a5804c96afed6e19be09dbbbe56ccaf82cff3d751f44aadb9626b77facfa
+ size 1134588350
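BEN_Base.pth is stored as a Git LFS object (about 1.13 GB). A minimal sketch of how the checkpoint is consumed, mirroring the `loadcheckpoints` method in `model.py` (the `model_state_dict` key comes from that code):

```python
import torch
import model  # model.py from this repository

net = model.BEN_Base().eval()

# BEN_Base.pth holds a dict with a 'model_state_dict' entry (see model.loadcheckpoints)
ckpt = torch.load("./BEN_Base.pth", map_location="cpu", weights_only=True)
net.load_state_dict(ckpt["model_state_dict"], strict=True)
```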
README.md CHANGED
@@ -1,3 +1,112 @@
  ---
  license: apache-2.0
+ pipeline_tag: image-segmentation
+ tags:
+ - BEN
+ - background-remove
+ - mask-generation
+ - Dichotomous image segmentation
+ - background remove
+ - foreground
+ - background
+ - remove background
+ - pytorch
  ---
+
+ # BEN: Background Erase Network
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.06230-b31b1b.svg)](https://arxiv.org/abs/2501.06230)
+ [![GitHub](https://img.shields.io/badge/GitHub-BEN-black.svg)](https://github.com/PramaLLC/BEN/)
+ [![Website](https://img.shields.io/badge/Website-backgrounderase.net-104233)](https://backgrounderase.net)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ben-using-confidence-guided-matting-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=ben-using-confidence-guided-matting-for)
+
+ ## Overview
+ BEN (Background Erase Network) introduces a novel approach to foreground segmentation through its Confidence Guided Matting (CGM) pipeline. The architecture employs a refiner network that targets and reprocesses the pixels where the base model exhibits lower confidence, producing more precise and reliable mattes.
+
+ This repository provides the official code for our model, as detailed in our research paper: [BEN: Background Erase Network](https://arxiv.org/abs/2501.06230).
+
+
+
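The refiner itself ships with the commercial model, but the confidence-guided idea above can be illustrated with a minimal sketch (the 0.5-centred confidence score and the 0.2 threshold are illustrative assumptions, not the actual CGM implementation):

```python
import torch

def low_confidence_mask(prob_map: torch.Tensor, threshold: float = 0.2) -> torch.Tensor:
    """Flag pixels whose foreground probability sits near 0.5, i.e. where the
    base model is least certain (assumed confidence measure)."""
    confidence = (prob_map - 0.5).abs() * 2.0   # 0 = uncertain, 1 = certain
    return confidence < threshold               # True where a refiner would focus

# prob_map = base_model_output   # (1, 1, H, W) in [0, 1], e.g. BEN_Base's sigmoid output
# refine_here = low_confidence_mask(prob_map)
```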
+ ## BEN2 Access
+ BEN2 is now publicly available, trained on DIS5K and our 22K proprietary segmentation dataset. Our enhanced model delivers superior performance in hair matting, 4K processing, object segmentation, and edge refinement. Access the base model on Hugging Face, try the full model through our free web demo, or integrate BEN2 into your project with our API:
+ - 🤗 [PramaLLC/BEN2](https://huggingface.co/PramaLLC/BEN2)
+ - 🌐 [backgrounderase.net](https://backgrounderase.net)
+
+ ## Model Access
+ The base model is publicly available on Hugging Face and is free for commercial use:
+ - 🤗 [PramaLLC/BEN](https://huggingface.co/PramaLLC/BEN)
+
+
+ ## Contact Us
+ - For access to our commercial model, email us at [email protected]
+ - Our website: https://pramadevelopment.com/
+ - Follow us on X: https://x.com/PramaResearch/
+
+
+ ## Quick Start Code (Inside Cloned Repo)
+
+ ```python
+ import model
+ from PIL import Image
+ import torch
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ file = "./image.png"  # input image
+
+ model = model.BEN_Base().to(device).eval()  # init pipeline
+
+ model.loadcheckpoints("./BEN_Base.pth")
+ image = Image.open(file)
+ mask, foreground = model.inference(image)
+
+ mask.save("./mask.png")
+ foreground.save("./foreground.png")
+ ```
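`inference` returns a grayscale mask and an RGBA foreground cutout (see `model.py`). As a follow-up, a minimal sketch of compositing that cutout onto a plain background; the white canvas and output file name are arbitrary choices:

```python
from PIL import Image

foreground = Image.open("./foreground.png").convert("RGBA")              # RGBA cutout from the quick start
background = Image.new("RGBA", foreground.size, (255, 255, 255, 255))    # plain white canvas

composite = Image.alpha_composite(background, foreground)
composite.convert("RGB").save("./composite.jpg", quality=95)
```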
+
+ # BEN SOTA Benchmarks on DIS5K Eval
+
+ ![Demo Results](demo.jpg)
+
+
+ ### BEN_Base + BEN_Refiner (commercial model, please contact us for more information):
+ - MAE: 0.0270
+ - DICE: 0.8989
+ - IOU: 0.8506
+ - BER: 0.0496
+ - ACC: 0.9740
+
+ ### BEN_Base (94 million parameters):
+ - MAE: 0.0309
+ - DICE: 0.8806
+ - IOU: 0.8371
+ - BER: 0.0516
+ - ACC: 0.9718
+
+
+ ### BiRefNet (not tested in-house):
+ - MAE: 0.038
+
+ ### MVANet (old SOTA):
+ - MAE: 0.0353
+ - DICE: 0.8676
+ - IOU: 0.8104
+ - BER: 0.0639
+ - ACC: 0.9660
+
+ ### InSPyReNet (not tested in-house):
+ - MAE: 0.042
+
+
+
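For reference, a minimal sketch of how such mask metrics are commonly computed from a predicted and a ground-truth mask; these are the standard definitions and are an assumption, not the BEN evaluation code:

```python
import numpy as np

def mask_metrics(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> dict:
    """pred, gt: float arrays in [0, 1] with identical shape."""
    p, g = (pred > 0.5).astype(np.float64), (gt > 0.5).astype(np.float64)
    tp, fp = (p * g).sum(), (p * (1 - g)).sum()
    fn, tn = ((1 - p) * g).sum(), ((1 - p) * (1 - g)).sum()
    return {
        "MAE": float(np.abs(pred - gt).mean()),
        "DICE": float(2 * tp / (2 * tp + fp + fn + eps)),
        "IOU": float(tp / (tp + fp + fn + eps)),
        "BER": float(1 - 0.5 * (tp / (tp + fn + eps) + tn / (tn + fp + eps))),
        "ACC": float((tp + tn) / (tp + tn + fp + fn + eps)),
    }
```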
+ ## Features
+ - Background removal from images
+ - Generates both a binary mask and a foreground image
+ - CUDA support for GPU acceleration
+ - Simple API for easy integration
+
+ ## Installation
+ 1. Clone the repo
+ 2. Install the dependencies from requirements.txt
config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_name_or_path": "PramaLLC/BEN",
+   "architectures": ["PramaBEN_Base"],
+   "version": "1.0",
+   "torch_dtype": "float32"
+ }
demo.jpg ADDED

Git LFS Details

  • SHA256: 14d508822dff813f0d013bd59ed978c2bb11b162ab159b9ce1c8e2f6ad686394
  • Pointer size: 132 Bytes
  • Size of remote file: 1.83 MB
image.png ADDED

Git LFS Details

  • SHA256: 26327e963cca60854de738d2509d0fd913a2282acdfa05dc9b141393617dc59e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
inference.py ADDED
@@ -0,0 +1,17 @@
+ import model
+ from PIL import Image
+ import torch
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ file = "./image.png"  # input image
+
+ model = model.BEN_Base().to(device).eval()  # init pipeline
+
+ model.loadcheckpoints("./BEN_Base.pth")
+ image = Image.open(file)
+ mask, foreground = model.inference(image)
+
+ mask.save("./mask.png")
+ foreground.save("./foreground.png")
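A small variation on the script above for batch-processing a folder with the same API is sketched below; the `inputs/` and `outputs/` directory names are placeholders:

```python
import os
from PIL import Image
import torch
import model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = model.BEN_Base().to(device).eval()
net.loadcheckpoints("./BEN_Base.pth")

os.makedirs("./outputs", exist_ok=True)
for name in os.listdir("./inputs"):
    if not name.lower().endswith((".jpg", ".jpeg", ".png")):
        continue
    image = Image.open(os.path.join("./inputs", name))
    mask, foreground = net.inference(image)          # grayscale mask, RGBA cutout
    stem = os.path.splitext(name)[0]
    mask.save(f"./outputs/{stem}_mask.png")
    foreground.save(f"./outputs/{stem}_foreground.png")
```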
model.py ADDED
@@ -0,0 +1,951 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ from einops import rearrange
8
+ from PIL import Image, ImageFilter, ImageOps
9
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
10
+ from torchvision import transforms
11
+
12
+ class Mlp(nn.Module):
13
+ """ Multilayer perceptron."""
14
+
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+ Returns:
39
+ windows: (num_windows*B, window_size, window_size, C)
40
+ """
41
+ B, H, W, C = x.shape
42
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
43
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
44
+ return windows
45
+
46
+
47
+ def window_reverse(windows, window_size, H, W):
48
+ """
49
+ Args:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ window_size (int): Window size
52
+ H (int): Height of image
53
+ W (int): Width of image
54
+ Returns:
55
+ x: (B, H, W, C)
56
+ """
57
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
58
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
59
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
60
+ return x
61
+
62
+
63
+ class WindowAttention(nn.Module):
64
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
65
+ It supports both of shifted and non-shifted window.
66
+ Args:
67
+ dim (int): Number of input channels.
68
+ window_size (tuple[int]): The height and width of the window.
69
+ num_heads (int): Number of attention heads.
70
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
72
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
73
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
74
+ """
75
+
76
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
77
+
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.window_size = window_size # Wh, Ww
81
+ self.num_heads = num_heads
82
+ head_dim = dim // num_heads
83
+ self.scale = qk_scale or head_dim ** -0.5
84
+
85
+ # define a parameter table of relative position bias
86
+ self.relative_position_bias_table = nn.Parameter(
87
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88
+
89
+ # get pair-wise relative position index for each token inside the window
90
+ coords_h = torch.arange(self.window_size[0])
91
+ coords_w = torch.arange(self.window_size[1])
92
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
93
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
96
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
97
+ relative_coords[:, :, 1] += self.window_size[1] - 1
98
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
99
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
100
+ self.register_buffer("relative_position_index", relative_position_index)
101
+
102
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103
+ self.attn_drop = nn.Dropout(attn_drop)
104
+ self.proj = nn.Linear(dim, dim)
105
+ self.proj_drop = nn.Dropout(proj_drop)
106
+
107
+ trunc_normal_(self.relative_position_bias_table, std=.02)
108
+ self.softmax = nn.Softmax(dim=-1)
109
+
110
+ def forward(self, x, mask=None):
111
+ """ Forward function.
112
+ Args:
113
+ x: input features with shape of (num_windows*B, N, C)
114
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
115
+ """
116
+ B_, N, C = x.shape
117
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
118
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
119
+
120
+ q = q * self.scale
121
+ attn = (q @ k.transpose(-2, -1))
122
+
123
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
124
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
125
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
126
+ attn = attn + relative_position_bias.unsqueeze(0)
127
+
128
+ if mask is not None:
129
+ nW = mask.shape[0]
130
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
131
+ attn = attn.view(-1, self.num_heads, N, N)
132
+ attn = self.softmax(attn)
133
+ else:
134
+ attn = self.softmax(attn)
135
+
136
+ attn = self.attn_drop(attn)
137
+
138
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
139
+ x = self.proj(x)
140
+ x = self.proj_drop(x)
141
+ return x
142
+
143
+
144
+ class SwinTransformerBlock(nn.Module):
145
+ """ Swin Transformer Block.
146
+ Args:
147
+ dim (int): Number of input channels.
148
+ num_heads (int): Number of attention heads.
149
+ window_size (int): Window size.
150
+ shift_size (int): Shift size for SW-MSA.
151
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
152
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
153
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
154
+ drop (float, optional): Dropout rate. Default: 0.0
155
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
156
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
157
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
158
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
159
+ """
160
+
161
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
162
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
163
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
164
+ super().__init__()
165
+ self.dim = dim
166
+ self.num_heads = num_heads
167
+ self.window_size = window_size
168
+ self.shift_size = shift_size
169
+ self.mlp_ratio = mlp_ratio
170
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
171
+
172
+ self.norm1 = norm_layer(dim)
173
+ self.attn = WindowAttention(
174
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
175
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
176
+
177
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
178
+ self.norm2 = norm_layer(dim)
179
+ mlp_hidden_dim = int(dim * mlp_ratio)
180
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
181
+
182
+ self.H = None
183
+ self.W = None
184
+
185
+ def forward(self, x, mask_matrix):
186
+ """ Forward function.
187
+ Args:
188
+ x: Input feature, tensor size (B, H*W, C).
189
+ H, W: Spatial resolution of the input feature.
190
+ mask_matrix: Attention mask for cyclic shift.
191
+ """
192
+ B, L, C = x.shape
193
+ H, W = self.H, self.W
194
+ assert L == H * W, "input feature has wrong size"
195
+
196
+ shortcut = x
197
+ x = self.norm1(x)
198
+ x = x.view(B, H, W, C)
199
+
200
+ # pad feature maps to multiples of window size
201
+ pad_l = pad_t = 0
202
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
203
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
204
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
205
+ _, Hp, Wp, _ = x.shape
206
+
207
+ # cyclic shift
208
+ if self.shift_size > 0:
209
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
210
+ attn_mask = mask_matrix
211
+ else:
212
+ shifted_x = x
213
+ attn_mask = None
214
+
215
+ # partition windows
216
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
217
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
218
+
219
+ # W-MSA/SW-MSA
220
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
221
+
222
+ # merge windows
223
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
224
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
225
+
226
+ # reverse cyclic shift
227
+ if self.shift_size > 0:
228
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
229
+ else:
230
+ x = shifted_x
231
+
232
+ if pad_r > 0 or pad_b > 0:
233
+ x = x[:, :H, :W, :].contiguous()
234
+
235
+ x = x.view(B, H * W, C)
236
+
237
+ # FFN
238
+ x = shortcut + self.drop_path(x)
239
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
240
+
241
+ return x
242
+
243
+
244
+ class PatchMerging(nn.Module):
245
+ """ Patch Merging Layer
246
+ Args:
247
+ dim (int): Number of input channels.
248
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
249
+ """
250
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
251
+ super().__init__()
252
+ self.dim = dim
253
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
254
+ self.norm = norm_layer(4 * dim)
255
+
256
+ def forward(self, x, H, W):
257
+ """ Forward function.
258
+ Args:
259
+ x: Input feature, tensor size (B, H*W, C).
260
+ H, W: Spatial resolution of the input feature.
261
+ """
262
+ B, L, C = x.shape
263
+ assert L == H * W, "input feature has wrong size"
264
+
265
+ x = x.view(B, H, W, C)
266
+
267
+ # padding
268
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
269
+ if pad_input:
270
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
271
+
272
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
273
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
274
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
275
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
276
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
277
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
278
+
279
+ x = self.norm(x)
280
+ x = self.reduction(x)
281
+
282
+ return x
283
+
284
+
285
+ class BasicLayer(nn.Module):
286
+ """ A basic Swin Transformer layer for one stage.
287
+ Args:
288
+ dim (int): Number of feature channels
289
+ depth (int): Depths of this stage.
290
+ num_heads (int): Number of attention head.
291
+ window_size (int): Local window size. Default: 7.
292
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
293
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
294
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
295
+ drop (float, optional): Dropout rate. Default: 0.0
296
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
297
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
298
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
299
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
300
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
301
+ """
302
+
303
+ def __init__(self,
304
+ dim,
305
+ depth,
306
+ num_heads,
307
+ window_size=7,
308
+ mlp_ratio=4.,
309
+ qkv_bias=True,
310
+ qk_scale=None,
311
+ drop=0.,
312
+ attn_drop=0.,
313
+ drop_path=0.,
314
+ norm_layer=nn.LayerNorm,
315
+ downsample=None,
316
+ use_checkpoint=False):
317
+ super().__init__()
318
+ self.window_size = window_size
319
+ self.shift_size = window_size // 2
320
+ self.depth = depth
321
+ self.use_checkpoint = use_checkpoint
322
+
323
+ # build blocks
324
+ self.blocks = nn.ModuleList([
325
+ SwinTransformerBlock(
326
+ dim=dim,
327
+ num_heads=num_heads,
328
+ window_size=window_size,
329
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
330
+ mlp_ratio=mlp_ratio,
331
+ qkv_bias=qkv_bias,
332
+ qk_scale=qk_scale,
333
+ drop=drop,
334
+ attn_drop=attn_drop,
335
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
336
+ norm_layer=norm_layer)
337
+ for i in range(depth)])
338
+
339
+ # patch merging layer
340
+ if downsample is not None:
341
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
342
+ else:
343
+ self.downsample = None
344
+
345
+ def forward(self, x, H, W):
346
+ """ Forward function.
347
+ Args:
348
+ x: Input feature, tensor size (B, H*W, C).
349
+ H, W: Spatial resolution of the input feature.
350
+ """
351
+
352
+ # calculate attention mask for SW-MSA
353
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
354
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
355
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
356
+ h_slices = (slice(0, -self.window_size),
357
+ slice(-self.window_size, -self.shift_size),
358
+ slice(-self.shift_size, None))
359
+ w_slices = (slice(0, -self.window_size),
360
+ slice(-self.window_size, -self.shift_size),
361
+ slice(-self.shift_size, None))
362
+ cnt = 0
363
+ for h in h_slices:
364
+ for w in w_slices:
365
+ img_mask[:, h, w, :] = cnt
366
+ cnt += 1
367
+
368
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
369
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
370
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
371
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
372
+
373
+ for blk in self.blocks:
374
+ blk.H, blk.W = H, W
375
+ if self.use_checkpoint:
376
+ x = checkpoint.checkpoint(blk, x, attn_mask)
377
+ else:
378
+ x = blk(x, attn_mask)
379
+ if self.downsample is not None:
380
+ x_down = self.downsample(x, H, W)
381
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
382
+ return x, H, W, x_down, Wh, Ww
383
+ else:
384
+ return x, H, W, x, H, W
385
+
386
+
387
+ class PatchEmbed(nn.Module):
388
+ """ Image to Patch Embedding
389
+ Args:
390
+ patch_size (int): Patch token size. Default: 4.
391
+ in_chans (int): Number of input image channels. Default: 3.
392
+ embed_dim (int): Number of linear projection output channels. Default: 96.
393
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
394
+ """
395
+
396
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
397
+ super().__init__()
398
+ patch_size = to_2tuple(patch_size)
399
+ self.patch_size = patch_size
400
+
401
+ self.in_chans = in_chans
402
+ self.embed_dim = embed_dim
403
+
404
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
405
+ if norm_layer is not None:
406
+ self.norm = norm_layer(embed_dim)
407
+ else:
408
+ self.norm = None
409
+
410
+ def forward(self, x):
411
+ """Forward function."""
412
+ # padding
413
+ _, _, H, W = x.size()
414
+ if W % self.patch_size[1] != 0:
415
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
416
+ if H % self.patch_size[0] != 0:
417
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
418
+
419
+ x = self.proj(x) # B C Wh Ww
420
+ if self.norm is not None:
421
+ Wh, Ww = x.size(2), x.size(3)
422
+ x = x.flatten(2).transpose(1, 2)
423
+ x = self.norm(x)
424
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
425
+
426
+ return x
427
+
428
+
429
+ class SwinTransformer(nn.Module):
430
+ """ Swin Transformer backbone.
431
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
432
+ https://arxiv.org/pdf/2103.14030
433
+ Args:
434
+ pretrain_img_size (int): Input image size for training the pretrained model,
435
+ used in absolute position embedding. Default 224.
436
+ patch_size (int | tuple(int)): Patch size. Default: 4.
437
+ in_chans (int): Number of input image channels. Default: 3.
438
+ embed_dim (int): Number of linear projection output channels. Default: 96.
439
+ depths (tuple[int]): Depths of each Swin Transformer stage.
440
+ num_heads (tuple[int]): Number of attention head of each stage.
441
+ window_size (int): Window size. Default: 7.
442
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
443
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
444
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
445
+ drop_rate (float): Dropout rate.
446
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
447
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
448
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
449
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
450
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
451
+ out_indices (Sequence[int]): Output from which stages.
452
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
453
+ -1 means not freezing any parameters.
454
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
455
+ """
456
+
457
+ def __init__(self,
458
+ pretrain_img_size=224,
459
+ patch_size=4,
460
+ in_chans=3,
461
+ embed_dim=96,
462
+ depths=[2, 2, 6, 2],
463
+ num_heads=[3, 6, 12, 24],
464
+ window_size=7,
465
+ mlp_ratio=4.,
466
+ qkv_bias=True,
467
+ qk_scale=None,
468
+ drop_rate=0.,
469
+ attn_drop_rate=0.,
470
+ drop_path_rate=0.2,
471
+ norm_layer=nn.LayerNorm,
472
+ ape=False,
473
+ patch_norm=True,
474
+ out_indices=(0, 1, 2, 3),
475
+ frozen_stages=-1,
476
+ use_checkpoint=False):
477
+ super().__init__()
478
+
479
+ self.pretrain_img_size = pretrain_img_size
480
+ self.num_layers = len(depths)
481
+ self.embed_dim = embed_dim
482
+ self.ape = ape
483
+ self.patch_norm = patch_norm
484
+ self.out_indices = out_indices
485
+ self.frozen_stages = frozen_stages
486
+
487
+ # split image into non-overlapping patches
488
+ self.patch_embed = PatchEmbed(
489
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
490
+ norm_layer=norm_layer if self.patch_norm else None)
491
+
492
+ # absolute position embedding
493
+ if self.ape:
494
+ pretrain_img_size = to_2tuple(pretrain_img_size)
495
+ patch_size = to_2tuple(patch_size)
496
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
497
+
498
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
499
+ trunc_normal_(self.absolute_pos_embed, std=.02)
500
+
501
+ self.pos_drop = nn.Dropout(p=drop_rate)
502
+
503
+ # stochastic depth
504
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
505
+
506
+ # build layers
507
+ self.layers = nn.ModuleList()
508
+ for i_layer in range(self.num_layers):
509
+ layer = BasicLayer(
510
+ dim=int(embed_dim * 2 ** i_layer),
511
+ depth=depths[i_layer],
512
+ num_heads=num_heads[i_layer],
513
+ window_size=window_size,
514
+ mlp_ratio=mlp_ratio,
515
+ qkv_bias=qkv_bias,
516
+ qk_scale=qk_scale,
517
+ drop=drop_rate,
518
+ attn_drop=attn_drop_rate,
519
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
520
+ norm_layer=norm_layer,
521
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
522
+ use_checkpoint=use_checkpoint)
523
+ self.layers.append(layer)
524
+
525
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
526
+ self.num_features = num_features
527
+
528
+ # add a norm layer for each output
529
+ for i_layer in out_indices:
530
+ layer = norm_layer(num_features[i_layer])
531
+ layer_name = f'norm{i_layer}'
532
+ self.add_module(layer_name, layer)
533
+
534
+ self._freeze_stages()
535
+
536
+ def _freeze_stages(self):
537
+ if self.frozen_stages >= 0:
538
+ self.patch_embed.eval()
539
+ for param in self.patch_embed.parameters():
540
+ param.requires_grad = False
541
+
542
+ if self.frozen_stages >= 1 and self.ape:
543
+ self.absolute_pos_embed.requires_grad = False
544
+
545
+ if self.frozen_stages >= 2:
546
+ self.pos_drop.eval()
547
+ for i in range(0, self.frozen_stages - 1):
548
+ m = self.layers[i]
549
+ m.eval()
550
+ for param in m.parameters():
551
+ param.requires_grad = False
552
+
553
+
554
+ def forward(self, x):
555
+
556
+ x = self.patch_embed(x)
557
+
558
+ Wh, Ww = x.size(2), x.size(3)
559
+ if self.ape:
560
+ # interpolate the position embedding to the corresponding size
561
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
562
+ x = (x + absolute_pos_embed) # B Wh*Ww C
563
+
564
+ outs = [x.contiguous()]
565
+ x = x.flatten(2).transpose(1, 2)
566
+ x = self.pos_drop(x)
567
+
568
+
569
+ for i in range(self.num_layers):
570
+ layer = self.layers[i]
571
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
572
+
573
+
574
+ if i in self.out_indices:
575
+ norm_layer = getattr(self, f'norm{i}')
576
+ x_out = norm_layer(x_out)
577
+
578
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
579
+ outs.append(out)
580
+
581
+
582
+
583
+ return tuple(outs)
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
+ def get_activation_fn(activation):
592
+ """Return an activation function given a string"""
593
+ if activation == "gelu":
594
+ return F.gelu
595
+
596
+ raise RuntimeError(F"activation should be gelu, not {activation}.")
597
+
598
+
599
+ def make_cbr(in_dim, out_dim):
600
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
601
+
602
+
603
+ def make_cbg(in_dim, out_dim):
604
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
605
+
606
+
607
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
608
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
609
+
610
+
611
+ def resize_as(x, y, interpolation='bilinear'):
612
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
613
+
614
+
615
+ def image2patches(x):
616
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
617
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
618
+ return x
619
+
620
+
621
+ def patches2image(x):
622
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
623
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
624
+ return x
625
+ class PositionEmbeddingSine:
626
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
627
+ super().__init__()
628
+ self.num_pos_feats = num_pos_feats
629
+ self.temperature = temperature
630
+ self.normalize = normalize
631
+ if scale is not None and normalize is False:
632
+ raise ValueError("normalize should be True if scale is passed")
633
+ if scale is None:
634
+ scale = 2 * math.pi
635
+ self.scale = scale
636
+ self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
637
+
638
+ def __call__(self, b, h, w):
639
+ device = self.dim_t.device
640
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
641
+ assert mask is not None
642
+ not_mask = ~mask
643
+ y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
644
+ x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
645
+ if self.normalize:
646
+ eps = 1e-6
647
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
648
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
649
+
650
+ dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
651
+ pos_x = x_embed[:, :, :, None] / dim_t
652
+ pos_y = y_embed[:, :, :, None] / dim_t
653
+
654
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
655
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
656
+
657
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
658
+
659
+
660
+ class MCLM(nn.Module):
661
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
662
+ super(MCLM, self).__init__()
663
+ self.attention = nn.ModuleList([
664
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
665
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
666
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
667
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
668
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
669
+ ])
670
+
671
+ self.linear1 = nn.Linear(d_model, d_model * 2)
672
+ self.linear2 = nn.Linear(d_model * 2, d_model)
673
+ self.linear3 = nn.Linear(d_model, d_model * 2)
674
+ self.linear4 = nn.Linear(d_model * 2, d_model)
675
+ self.norm1 = nn.LayerNorm(d_model)
676
+ self.norm2 = nn.LayerNorm(d_model)
677
+ self.dropout = nn.Dropout(0.1)
678
+ self.dropout1 = nn.Dropout(0.1)
679
+ self.dropout2 = nn.Dropout(0.1)
680
+ self.activation = get_activation_fn('gelu')
681
+ self.pool_ratios = pool_ratios
682
+ self.p_poses = []
683
+ self.g_pos = None
684
+ self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
685
+
686
+ def forward(self, l, g):
687
+ """
688
+ l: 4,c,h,w
689
+ g: 1,c,h,w
690
+ """
691
+ b, c, h, w = l.size()
692
+ # 4,c,h,w -> 1,c,2h,2w
693
+ concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
694
+
695
+ pools = []
696
+ for pool_ratio in self.pool_ratios:
697
+ # b,c,h,w
698
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
699
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
700
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
701
+ if self.g_pos is None:
702
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
703
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
704
+ self.p_poses.append(pos_emb)
705
+ pools = torch.cat(pools, 0)
706
+ if self.g_pos is None:
707
+ self.p_poses = torch.cat(self.p_poses, dim=0)
708
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
709
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
710
+
711
+ device = pools.device
712
+ self.p_poses = self.p_poses.to(device)
713
+ self.g_pos = self.g_pos.to(device)
714
+
715
+
716
+ # attention between glb (q) & multi-scale pooled concatenated locs (k, v)
717
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
718
+
719
+
720
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
721
+ g_hw_b_c = self.norm1(g_hw_b_c)
722
+ g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
723
+ g_hw_b_c = self.norm2(g_hw_b_c)
724
+
725
+ # attention between original locs (q) & refreshed glb (k, v)
726
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
727
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
728
+ _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
729
+ outputs_re = []
730
+ for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
731
+ outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
732
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
733
+
734
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
735
+ l_hw_b_c = self.norm1(l_hw_b_c)
736
+ l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
737
+ l_hw_b_c = self.norm2(l_hw_b_c)
738
+
739
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
740
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
741
+
742
+
743
+
744
+
745
+
746
+
747
+
748
+
749
+
750
+ class MCRM(nn.Module):
751
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
752
+ super(MCRM, self).__init__()
753
+ self.attention = nn.ModuleList([
754
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
755
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
756
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
757
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
758
+ ])
759
+ self.linear3 = nn.Linear(d_model, d_model * 2)
760
+ self.linear4 = nn.Linear(d_model * 2, d_model)
761
+ self.norm1 = nn.LayerNorm(d_model)
762
+ self.norm2 = nn.LayerNorm(d_model)
763
+ self.dropout = nn.Dropout(0.1)
764
+ self.dropout1 = nn.Dropout(0.1)
765
+ self.dropout2 = nn.Dropout(0.1)
766
+ self.sigmoid = nn.Sigmoid()
767
+ self.activation = get_activation_fn('gelu')
768
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
769
+ self.pool_ratios = pool_ratios
770
+
771
+ def forward(self, x):
772
+ device = x.device
773
+ b, c, h, w = x.size()
774
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
775
+
776
+ patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
777
+
778
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
779
+ token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
780
+ loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
781
+
782
+ pools = []
783
+ for pool_ratio in self.pool_ratios:
784
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
785
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
786
+ pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
787
+
788
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
789
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
790
+
791
+ outputs = []
792
+ for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
793
+ v = pools[i]
794
+ k = v
795
+ outputs.append(self.attention[i](q, k, v)[0])
796
+
797
+ outputs = torch.cat(outputs, 1)
798
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
799
+ src = self.norm1(src)
800
+ src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
801
+ src = self.norm2(src)
802
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
803
+ glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # refreshed glb
804
+
805
+ return torch.cat((src, glb), 0), token_attention_map
806
+
807
+
808
+ class BEN_Base(nn.Module):
809
+ def __init__(self):
810
+ super().__init__()
811
+
812
+ self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
813
+ emb_dim = 128
814
+ self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
815
+ self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
816
+ self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
817
+ self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
818
+ self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
819
+
820
+ self.output5 = make_cbr(1024, emb_dim)
821
+ self.output4 = make_cbr(512, emb_dim)
822
+ self.output3 = make_cbr(256, emb_dim)
823
+ self.output2 = make_cbr(128, emb_dim)
824
+ self.output1 = make_cbr(128, emb_dim)
825
+
826
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
827
+ self.conv1 = make_cbr(emb_dim, emb_dim)
828
+ self.conv2 = make_cbr(emb_dim, emb_dim)
829
+ self.conv3 = make_cbr(emb_dim, emb_dim)
830
+ self.conv4 = make_cbr(emb_dim, emb_dim)
831
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
832
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
833
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
834
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
835
+
836
+ self.insmask_head = nn.Sequential(
837
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
838
+ nn.InstanceNorm2d(384),
839
+ nn.GELU(),
840
+ nn.Conv2d(384, 384, kernel_size=3, padding=1),
841
+ nn.InstanceNorm2d(384),
842
+ nn.GELU(),
843
+ nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
844
+ )
845
+
846
+ self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
847
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
848
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
849
+ self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
850
+
851
+ for m in self.modules():
852
+ if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
853
+ m.inplace = True
854
+
855
+ def forward(self, x):
856
+ device = x.device
857
+ shallow = self.shallow(x)
858
+ glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
859
+ loc = image2patches(x)
860
+ input = torch.cat((loc, glb), dim=0)
861
+ feature = self.backbone(input)
862
+ e5 = self.output5(feature[4]) # (5,128,16,16)
863
+ e4 = self.output4(feature[3]) # (5,128,32,32)
864
+ e3 = self.output3(feature[2]) # (5,128,64,64)
865
+ e2 = self.output2(feature[1]) # (5,128,128,128)
866
+ e1 = self.output1(feature[0]) # (5,128,128,128)
867
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
868
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
869
+
870
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
871
+ e4 = self.conv4(e4)
872
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
873
+ e3 = self.conv3(e3)
874
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
875
+ e2 = self.conv2(e2)
876
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
877
+ e1 = self.conv1(e1)
878
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
879
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
880
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
881
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
882
+ final_output = final_output + resize_as(shallow, final_output)
883
+ final_output = self.upsample1(rescale_to(final_output))
884
+ final_output = rescale_to(final_output + resize_as(shallow, final_output))
885
+ final_output = self.upsample2(final_output)
886
+ final_output = self.output(final_output)
887
+
888
+ return final_output.sigmoid()
889
+
890
+ @torch.no_grad()
891
+ def inference(self,image):
892
+ image, h, w,original_image = rgb_loader_refiner(image)
893
+
894
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
895
+
896
+ res = self.forward(img_tensor)
897
+
898
+ pred_array = postprocess_image(res, im_size=[w, h])
899
+
900
+ mask_image = Image.fromarray(pred_array, mode='L')
901
+
902
+ blurred_mask = mask_image.filter(ImageFilter.GaussianBlur(radius=1))
903
+
904
+ original_image_rgba = original_image.convert("RGBA")
905
+
906
+ foreground = original_image_rgba.copy()
907
+
908
+ foreground.putalpha(blurred_mask)
909
+
910
+ return blurred_mask, foreground
911
+
912
+ def loadcheckpoints(self,model_path):
913
+ model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
914
+ self.load_state_dict(model_dict['model_state_dict'], strict=True)
915
+ del model_path
916
+
917
+
918
+
919
+
920
+ def rgb_loader_refiner( original_image):
921
+ h, w = original_image.size
922
+ # # Apply EXIF orientation
923
+ image = ImageOps.exif_transpose(original_image)
924
+ # Convert to RGB if necessary
925
+ if image.mode != 'RGB':
926
+ image = image.convert('RGB')
927
+
928
+ # Resize the image
929
+ image = image.resize((1024, 1024), resample=Image.LANCZOS)
930
+
931
+ return image.convert('RGB'), h, w,original_image
932
+
933
+ # Define the image transformation
934
+ img_transform = transforms.Compose([
935
+ transforms.ToTensor(),
936
+ transforms.ConvertImageDtype(torch.float32),
937
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
938
+ ])
939
+
940
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
941
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
942
+ ma = torch.max(result)
943
+ mi = torch.min(result)
944
+ result = (result - mi) / (ma - mi)
945
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
946
+ im_array = np.squeeze(im_array)
947
+ return im_array
948
+
949
+
950
+
951
+
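A minimal shape-level smoke test of `BEN_Base` as defined above, using randomly initialised weights and a 1024×1024 RGB input (the size `rgb_loader_refiner` resizes to); the expected output shape follows the comments in `forward`:

```python
import torch
import model  # model.py above

net = model.BEN_Base().eval()
with torch.no_grad():
    dummy = torch.rand(1, 3, 1024, 1024)   # one RGB image, values in [0, 1]
    out = net(dummy)                        # sigmoid saliency map
print(out.shape, out.min().item(), out.max().item())  # expected: (1, 1, 1024, 1024), values in (0, 1)
```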
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy>=1.21.0
+ torch>=1.9.0
+ einops>=0.6.0
+ Pillow>=9.0.0
+ timm>=0.9.0
+ torchvision>=0.10.0