Upload 9 files
- .gitattributes +43 -35
- BEN_Base.pth +3 -0
- README.md +109 -0
- config.json +6 -0
- demo.jpg +3 -0
- image.png +3 -0
- inference.py +17 -0
- model.py +951 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
@@ -1,35 +1,43 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+027.jpg filter=lfs diff=lfs merge=lfs -text
+027.png filter=lfs diff=lfs merge=lfs -text
+4340.png filter=lfs diff=lfs merge=lfs -text
+5442.jpg filter=lfs diff=lfs merge=lfs -text
+5442.png filter=lfs diff=lfs merge=lfs -text
+image.jpg filter=lfs diff=lfs merge=lfs -text
+image.png filter=lfs diff=lfs merge=lfs -text
+demo.jpg filter=lfs diff=lfs merge=lfs -text
BEN_Base.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f00a5804c96afed6e19be09dbbbe56ccaf82cff3d751f44aadb9626b77facfa
+size 1134588350
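The three lines above are only a Git LFS pointer; the actual 1.13 GB checkpoint is fetched by LFS and should hash to the `oid` shown. A minimal sketch (my addition, assuming the weights have already been downloaded next to the script) for verifying the download is intact:

```python
import hashlib

# Expected digest taken from the LFS pointer above.
expected = "9f00a5804c96afed6e19be09dbbbe56ccaf82cff3d751f44aadb9626b77facfa"

# Hash the checkpoint in 1 MiB chunks to avoid loading the whole file into memory.
sha = hashlib.sha256()
with open("./BEN_Base.pth", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

assert sha.hexdigest() == expected, "BEN_Base.pth does not match the LFS pointer"
print("checksum OK")
```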
README.md
CHANGED
@@ -1,3 +1,112 @@
 ---
 license: apache-2.0
+pipeline_tag: image-segmentation
+tags:
+- BEN
+- background-remove
+- mask-generation
+- Dichotomous image segmentation
+- background remove
+- foreground
+- background
+- remove background
+- pytorch
 ---
+
+# BEN: Background Erase Network
+
+[arXiv](https://arxiv.org/abs/2501.06230)
+[GitHub](https://github.com/PramaLLC/BEN/)
+[Website](https://backgrounderase.net)
+[Papers with Code](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=ben-using-confidence-guided-matting-for)
+
+## Overview
+BEN (Background Erase Network) introduces a novel approach to foreground segmentation through its Confidence Guided Matting (CGM) pipeline. The architecture employs a refiner network that targets and reprocesses the pixels where the base model is least confident, resulting in more precise and reliable matting.
+
+This repository provides the official code for our model, as detailed in our research paper: [BEN: Background Erase Network](https://arxiv.org/abs/2501.06230).
+
+
+## BEN2 Access
+BEN2 is now publicly available, trained on DIS5K and our proprietary 22K-image segmentation dataset. The enhanced model delivers superior performance in hair matting, 4K processing, object segmentation, and edge refinement. Access the base model on Hugging Face, try the full model through our free web demo, or integrate BEN2 into your project with our API:
+- 🤗 [PramaLLC/BEN2](https://huggingface.co/PramaLLC/BEN2)
+- 🌐 [backgrounderase.net](https://backgrounderase.net)
+
+## Model Access
+The base model is publicly available on Hugging Face and is free for commercial use:
+- 🤗 [PramaLLC/BEN](https://huggingface.co/PramaLLC/BEN)
+
+
+## Contact Us
+- For access to our commercial model, email us at [email protected]
+- Our website: https://pramadevelopment.com/
+- Follow us on X: https://x.com/PramaResearch/
+
+
+## Quick Start Code (Inside Cloned Repo)
+
+```python
+import model
+from PIL import Image
+import torch
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+file = "./image.png"  # input image
+
+model = model.BEN_Base().to(device).eval()  # init pipeline
+
+model.loadcheckpoints("./BEN_Base.pth")
+image = Image.open(file)
+mask, foreground = model.inference(image)
+
+mask.save("./mask.png")
+foreground.save("./foreground.png")
+```
+
+# BEN SOTA Benchmarks on DIS 5K Eval
+
+
+
+### BEN_Base + BEN_Refiner (commercial model; please contact us for more information):
+- MAE: 0.0270
+- DICE: 0.8989
+- IOU: 0.8506
+- BER: 0.0496
+- ACC: 0.9740
+
+### BEN_Base (94 million parameters):
+- MAE: 0.0309
+- DICE: 0.8806
+- IOU: 0.8371
+- BER: 0.0516
+- ACC: 0.9718
+
+### MVANet (old SOTA):
+- MAE: 0.0353
+- DICE: 0.8676
+- IOU: 0.8104
+- BER: 0.0639
+- ACC: 0.9660
+
+### BiRefNet (not tested in-house):
+- MAE: 0.038
+
+### InSPyReNet (not tested in-house):
+- MAE: 0.042
+
+
+## Features
+- Background removal from images
+- Generates both a binary mask and a foreground image
+- CUDA support for GPU acceleration
+- Simple API for easy integration
+
+## Installation
+1. Clone the repo
+2. Install requirements.txt
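The quick-start snippet in the README handles one image at a time. As a small extension sketch (my addition, not part of this upload, and assuming the same `BEN_Base` API shown above), the same pipeline can be looped over a folder; the directory names are placeholders:

```python
from pathlib import Path

import torch
from PIL import Image

import model as ben  # import under an alias so the instance does not shadow the module

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = ben.BEN_Base().to(device).eval()
net.loadcheckpoints("./BEN_Base.pth")

in_dir, out_dir = Path("./inputs"), Path("./outputs")  # placeholder folders
out_dir.mkdir(exist_ok=True)

for path in sorted(in_dir.glob("*.png")):
    image = Image.open(path)
    mask, foreground = net.inference(image)  # same API as the quick start
    mask.save(out_dir / f"{path.stem}_mask.png")
    foreground.save(out_dir / f"{path.stem}_foreground.png")
```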
config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "_name_or_path": "PramaLLC/BEN",
+  "architectures": ["PramaBEN_Base"],
+  "version": "1.0",
+  "torch_dtype": "float32"
+}
demo.jpg
ADDED
(image preview omitted; binary file stored via Git LFS)
image.png
ADDED
(image preview omitted; binary file stored via Git LFS)
inference.py
ADDED
@@ -0,0 +1,17 @@
+import model
+from PIL import Image
+import torch
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+file = "./image.png"  # input image
+
+model = model.BEN_Base().to(device).eval()  # init pipeline
+
+model.loadcheckpoints("./BEN_Base.pth")
+image = Image.open(file)
+mask, foreground = model.inference(image)
+
+mask.save("./mask.png")
+foreground.save("./foreground.png")
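inference.py writes `mask.png` and `foreground.png`; the foreground is an RGBA image whose alpha channel is the (slightly blurred) predicted matte. A follow-on sketch, my addition rather than part of the upload, that composites that cutout over a new backdrop with plain PIL (the background file name is a placeholder):

```python
from PIL import Image

# RGBA cutout produced by inference.py; its alpha channel is the predicted matte.
foreground = Image.open("./foreground.png").convert("RGBA")

# Placeholder backdrop, resized to match the cutout before compositing.
background = Image.open("./background.jpg").convert("RGBA").resize(foreground.size)

# Alpha-composite the cutout over the new background and save an opaque result.
composite = Image.alpha_composite(background, foreground)
composite.convert("RGB").save("./composited.jpg")
```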
model.py
ADDED
@@ -0,0 +1,951 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from PIL import Image, ImageFilter, ImageOps
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+from torchvision import transforms
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron."""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute postion embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 out_indices=(0, 1, 2, 3),
+                 frozen_stages=-1,
+                 use_checkpoint=False):
+        super().__init__()
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint)
+            self.layers.append(layer)
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in out_indices:
+            layer = norm_layer(num_features[i_layer])
+            layer_name = f'norm{i_layer}'
+            self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+
+    def forward(self, x):
+
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed)  # B Wh*Ww C
+
+        outs = [x.contiguous()]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+
+
+        return tuple(outs)
+
+
+
+
+
+
+
+
+def get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "gelu":
+        return F.gelu
+
+    raise RuntimeError(F"activation should be gelu, not {activation}.")
+
+
+def make_cbr(in_dim, out_dim):
+    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
+
+
+def make_cbg(in_dim, out_dim):
+    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
+
+
+def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
+    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
+
+
+def resize_as(x, y, interpolation='bilinear'):
+    return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
+
+
+def image2patches(x):
+    """b c (hg h) (wg w) -> (hg wg b) c h w"""
+    x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
+    return x
+
+
+def patches2image(x):
+    """(hg wg b) c h w -> b c (hg h) (wg w)"""
+    x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
+    return x
+
+
+class PositionEmbeddingSine:
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+        self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
+
+    def __call__(self, b, h, w):
+        device = self.dim_t.device
+        mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
+        assert mask is not None
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+
+        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+
+class MCLM(nn.Module):
+    def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
+        super(MCLM, self).__init__()
+        self.attention = nn.ModuleList([
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
+        ])
+
+        self.linear1 = nn.Linear(d_model, d_model * 2)
+        self.linear2 = nn.Linear(d_model * 2, d_model)
+        self.linear3 = nn.Linear(d_model, d_model * 2)
+        self.linear4 = nn.Linear(d_model * 2, d_model)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(0.1)
+        self.dropout1 = nn.Dropout(0.1)
+        self.dropout2 = nn.Dropout(0.1)
+        self.activation = get_activation_fn('gelu')
+        self.pool_ratios = pool_ratios
+        self.p_poses = []
+        self.g_pos = None
+        self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
+
+    def forward(self, l, g):
+        """
+        l: 4,c,h,w
+        g: 1,c,h,w
+        """
+        b, c, h, w = l.size()
+        # 4,c,h,w -> 1,c,2h,2w
+        concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
+
+        pools = []
+        for pool_ratio in self.pool_ratios:
+            # b,c,h,w
+            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
+            pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
+            pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
+            if self.g_pos is None:
+                pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
+                pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
+                self.p_poses.append(pos_emb)
+        pools = torch.cat(pools, 0)
+        if self.g_pos is None:
+            self.p_poses = torch.cat(self.p_poses, dim=0)
+            pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
+            self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
+
+        device = pools.device
+        self.p_poses = self.p_poses.to(device)
+        self.g_pos = self.g_pos.to(device)
+
+
+        # attention between glb (q) & multisensory concated-locs (k,v)
+        g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
+
+
+        g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
+        g_hw_b_c = self.norm1(g_hw_b_c)
+        g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
+        g_hw_b_c = self.norm2(g_hw_b_c)
+
+        # attention between origin locs (q) & freashed glb (k,v)
+        l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
+        _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
+        _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
+        outputs_re = []
+        for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
+            outputs_re.append(self.attention[i + 1](_l, _g, _g)[0])  # (h w) 1 c
+        outputs_re = torch.cat(outputs_re, 1)  # (h w) 4 c
+
+        l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
+        l_hw_b_c = self.norm1(l_hw_b_c)
+        l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
+        l_hw_b_c = self.norm2(l_hw_b_c)
+
+        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # hw,b(5),c
+        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  ## (5,c,h*w)
+
+
+
+
+
+
+
+
+
+class MCRM(nn.Module):
+    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
+        super(MCRM, self).__init__()
+        self.attention = nn.ModuleList([
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
+            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
+        ])
+        self.linear3 = nn.Linear(d_model, d_model * 2)
+        self.linear4 = nn.Linear(d_model * 2, d_model)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(0.1)
+        self.dropout1 = nn.Dropout(0.1)
+        self.dropout2 = nn.Dropout(0.1)
+        self.sigmoid = nn.Sigmoid()
+        self.activation = get_activation_fn('gelu')
+        self.sal_conv = nn.Conv2d(d_model, 1, 1)
+        self.pool_ratios = pool_ratios
+
+    def forward(self, x):
+        device = x.device
+        b, c, h, w = x.size()
+        loc, glb = x.split([4, 1], dim=0)  # 4,c,h,w; 1,c,h,w
+
+        patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
+
+        token_attention_map = self.sigmoid(self.sal_conv(glb))
+        token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
+        loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
+
+        pools = []
+        for pool_ratio in self.pool_ratios:
+            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
+            pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
+            pools.append(rearrange(pool, 'nl c h w -> nl c (h w)'))  # nl(4),c,hw
+
+        pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
+        loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
+
+        outputs = []
+        for i, q in enumerate(loc_.unbind(dim=0)):  # traverse all local patches
+            v = pools[i]
+            k = v
+            outputs.append(self.attention[i](q, k, v)[0])
+
+        outputs = torch.cat(outputs, 1)
+        src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
+        src = self.norm1(src)
+        src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
+        src = self.norm2(src)
+        src = src.permute(1, 2, 0).reshape(4, c, h, w)  # freshed loc
+        glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # freshed glb
+
+        return torch.cat((src, glb), 0), token_attention_map
+
+
+class BEN_Base(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
+        emb_dim = 128
+        self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+        self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+        self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+        self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+        self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+
+        self.output5 = make_cbr(1024, emb_dim)
+        self.output4 = make_cbr(512, emb_dim)
+        self.output3 = make_cbr(256, emb_dim)
+        self.output2 = make_cbr(128, emb_dim)
+        self.output1 = make_cbr(128, emb_dim)
+
+        self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
+        self.conv1 = make_cbr(emb_dim, emb_dim)
+        self.conv2 = make_cbr(emb_dim, emb_dim)
+        self.conv3 = make_cbr(emb_dim, emb_dim)
+        self.conv4 = make_cbr(emb_dim, emb_dim)
+        self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
+        self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
+        self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
+        self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
+
+        self.insmask_head = nn.Sequential(
+            nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(384),
+            nn.GELU(),
+            nn.Conv2d(384, 384, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(384),
+            nn.GELU(),
+            nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
+        )
+
+        self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
+        self.upsample1 = make_cbg(emb_dim, emb_dim)
+        self.upsample2 = make_cbg(emb_dim, emb_dim)
+        self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
+
+        for m in self.modules():
+            if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
+                m.inplace = True
+
+    def forward(self, x):
+        device = x.device
+        shallow = self.shallow(x)
+        glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
+        loc = image2patches(x)
+        input = torch.cat((loc, glb), dim=0)
+        feature = self.backbone(input)
+        e5 = self.output5(feature[4])  # (5,128,16,16)
+        e4 = self.output4(feature[3])  # (5,128,32,32)
+        e3 = self.output3(feature[2])  # (5,128,64,64)
+        e2 = self.output2(feature[1])  # (5,128,128,128)
+        e1 = self.output1(feature[0])  # (5,128,128,128)
+        loc_e5, glb_e5 = e5.split([4, 1], dim=0)
+        e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (4,128,16,16)
+
+        e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
+        e4 = self.conv4(e4)
+        e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
+        e3 = self.conv3(e3)
+        e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
+        e2 = self.conv2(e2)
+        e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
+        e1 = self.conv1(e1)
+        loc_e1, glb_e1 = e1.split([4, 1], dim=0)
+        output1_cat = patches2image(loc_e1)  # (1,128,256,256)
+        output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
+        final_output = self.insmask_head(output1_cat)  # (1,128,256,256)
+        final_output = final_output + resize_as(shallow, final_output)
+        final_output = self.upsample1(rescale_to(final_output))
+        final_output = rescale_to(final_output + resize_as(shallow, final_output))
+        final_output = self.upsample2(final_output)
+        final_output = self.output(final_output)
+
+        return final_output.sigmoid()
+
+    @torch.no_grad()
+    def inference(self, image):
+        image, h, w, original_image = rgb_loader_refiner(image)
+
+        img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
+
+        res = self.forward(img_tensor)
+
+        pred_array = postprocess_image(res, im_size=[w, h])
+
+        mask_image = Image.fromarray(pred_array, mode='L')
+
+        blurred_mask = mask_image.filter(ImageFilter.GaussianBlur(radius=1))
+
+        original_image_rgba = original_image.convert("RGBA")
+
+        foreground = original_image_rgba.copy()
+
+        foreground.putalpha(blurred_mask)
+
+        return blurred_mask, foreground
+
+    def loadcheckpoints(self, model_path):
+        model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+        self.load_state_dict(model_dict['model_state_dict'], strict=True)
+        del model_path
+
+
+
+
+def rgb_loader_refiner(original_image):
+    h, w = original_image.size
+    # Apply EXIF orientation
+    image = ImageOps.exif_transpose(original_image)
+    # Convert to RGB if necessary
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+
+    # Resize the image
+    image = image.resize((1024, 1024), resample=Image.LANCZOS)
+
+    return image.convert('RGB'), h, w, original_image
+
+# Define the image transformation
+img_transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.ConvertImageDtype(torch.float32),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
+    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+    ma = torch.max(result)
+    mi = torch.min(result)
+    result = (result - mi) / (ma - mi)
+    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+    im_array = np.squeeze(im_array)
+    return im_array
+
+
+
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+numpy>=1.21.0
+torch>=1.9.0
+einops>=0.6.0
+Pillow>=9.0.0
+timm>=0.6.0
+torchvision>=0.10.0