Latest version
Files changed:
- .gitattributes +31 -31
- README.md +13 -13
- app.py +32 -32
- example/Standing_jaguar.jpg +0 -0
- gcvit/__init__.py +1 -1
- gcvit/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/layers/__init__.py +7 -7
- gcvit/layers/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/attention.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/block.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/drop.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/embedding.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/feature.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/level.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/window.cpython-38.pyc +0 -0
- gcvit/layers/block.py +98 -98
- gcvit/layers/embedding.py +1 -1
- gcvit/layers/feature.py +254 -201
- gcvit/layers/level.py +84 -92
- gcvit/models/__init__.py +1 -1
- gcvit/models/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/models/__pycache__/gcvit.cpython-38.pyc +0 -0
- gcvit/models/gcvit.py +180 -145
- gcvit/utils/gradcam.py +68 -68
- gcvit/version.py +1 -1
- requirements.txt +4 -4
- setup.py +49 -49
.gitattributes
CHANGED
@@ -1,31 +1,31 @@
(all 31 lines removed and re-added with identical visible content, likely whitespace or line-ending normalization)

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,13 @@
(all 13 lines removed and re-added with identical visible content)

---
title: Gcvit Tf
emoji: 📈
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 3.1.0
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -1,33 +1,33 @@
(lines 1-32 removed and re-added with identical visible content; line 33, demo.launch(), is unchanged context)

import tensorflow as tf
import gradio as gr
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

def predict_fn(image, model_name):
    """A predict function that will be invoked by gradio."""
    model = getattr(gcvit, model_name)(pretrain=True)
    gradcam_model = get_gradcam_model(model)
    preds, overlay = get_gradcam_prediction(image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None)
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]

demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.inputs.Image(label="Input Image"),
        gr.Radio(['GCViTTiny', 'GCViTSmall', 'GCViTBase'], value='GCViTTiny', label='Model Name')
    ],
    outputs=[
        gr.outputs.Label(label="Prediction"),
        gr.inputs.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with GCViT Model using ImageNet Pretrain Weights.",
    examples=[
        ["example/hot_air_ballon.jpg", 'GCViTTiny'],
        ["example/chelsea.png", 'GCViTTiny'],
        ["example/penguin.JPG", 'GCViTTiny'],
        ["example/bus.jpg", 'GCViTTiny'],
    ],
)
demo.launch()
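For reference, predict_fn can be exercised without the Gradio UI; a minimal smoke-test sketch (the image path is one of the bundled examples, and downloading the pretrained GCViTTiny weights is assumed to succeed):

import numpy as np
from PIL import Image

img = np.array(Image.open("example/chelsea.png").convert("RGB"))
preds, overlay = predict_fn(img, "GCViTTiny")            # ({label: probability}, PIL overlay image)
print(sorted(preds.items(), key=lambda kv: -kv[1])[:3])  # top-3 classes by probability
overlay.save("gradcam_overlay.png")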
example/Standing_jaguar.jpg
ADDED
gcvit/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .models import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .models import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase
 from .version import __version__
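After this change both new variants are importable from the package root; e.g. (a sketch):

import gcvit
print(gcvit.__version__)                  # "1.0.9" as of this commit
model = gcvit.GCViTXTiny(pretrain=False)  # likewise GCViTXXTiny, GCViTTiny, GCViTSmall, GCViTBase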
gcvit/__pycache__/__init__.cpython-38.pyc
DELETED (binary file, 228 Bytes)
gcvit/layers/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 from .window import window_partition, window_reverse
 from .attention import WindowAttention
 from .drop import DropPath, Identity
-from .embedding import
+from .embedding import Stem
 from .feature import Mlp, FeatExtract, ReduceSize, SE, Resizing
 from .block import GCViTBlock
-from .level import
+from .level import GCViTLevel
gcvit/layers/__pycache__/__init__.cpython-38.pyc
DELETED (binary file, 530 Bytes)

gcvit/layers/__pycache__/attention.cpython-38.pyc
DELETED (binary file, 3.58 kB)

gcvit/layers/__pycache__/block.cpython-38.pyc
DELETED (binary file, 3 kB)

gcvit/layers/__pycache__/drop.cpython-38.pyc
DELETED (binary file, 1.8 kB)

gcvit/layers/__pycache__/embedding.cpython-38.pyc
DELETED (binary file, 1.39 kB)

gcvit/layers/__pycache__/feature.cpython-38.pyc
DELETED (binary file, 5.5 kB)

gcvit/layers/__pycache__/level.cpython-38.pyc
DELETED (binary file, 3 kB)

gcvit/layers/__pycache__/window.cpython-38.pyc
DELETED (binary file, 801 Bytes)
gcvit/layers/block.py
CHANGED
@@ -1,99 +1,99 @@
(lines 1-98 removed and re-added with identical visible content; line 99 is unchanged context)

import tensorflow as tf

from .attention import WindowAttention
from .drop import DropPath
from .window import window_partition, window_reverse
from .feature import Mlp, FeatExtract


@tf.keras.utils.register_keras_serializable(package="gcvit")
class GCViTBlock(tf.keras.layers.Layer):
    def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
                 attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.num_heads = num_heads
        self.global_query = global_query
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop = drop
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.act_layer = act_layer
        self.layer_scale = layer_scale

    def build(self, input_shape):
        B, H, W, C = input_shape[0]
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
        self.attn = WindowAttention(window_size=self.window_size,
                                    num_heads=self.num_heads,
                                    global_query=self.global_query,
                                    qkv_bias=self.qkv_bias,
                                    qk_scale=self.qk_scale,
                                    attn_dropout=self.attn_drop,
                                    proj_dropout=self.drop,
                                    name='attn')
        self.drop_path1 = DropPath(self.path_drop)
        self.drop_path2 = DropPath(self.path_drop)
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
        if self.layer_scale is not None:
            self.gamma1 = self.add_weight(
                'gamma1',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
            self.gamma2 = self.add_weight(
                'gamma2',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.num_windows = int(H // self.window_size) * int(W // self.window_size)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        if self.global_query:
            inputs, q_global = inputs
        else:
            inputs = inputs[0]
        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
        x = self.norm1(inputs)
        # create windows and concat them in batch axis
        x = window_partition(x, self.window_size)  # (B_, win_h, win_w, C)
        # flatten patch
        x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C])  # (B_, N, C) => (batch*num_win, num_token, feature)
        # attention
        if self.global_query:
            x = self.attn([x, q_global])
        else:
            x = self.attn([x])
        # reverse window partition
        x = window_reverse(x, self.window_size, H, W, C)
        # FFN
        x = inputs + self.drop_path1(x * self.gamma1)
        x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
            'num_heads': self.num_heads,
            'global_query': self.global_query,
            'mlp_ratio': self.mlp_ratio,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'drop': self.drop,
            'attn_drop': self.attn_drop,
            'path_drop': self.path_drop,
            'act_layer': self.act_layer,
            'layer_scale': self.layer_scale,
            'num_windows': self.num_windows,
        })
        return config
gcvit/layers/embedding.py
CHANGED
@@ -4,7 +4,7 @@ from .feature import ReduceSize
 
 
 @tf.keras.utils.register_keras_serializable(package="gcvit")
-class
+class Stem(tf.keras.layers.Layer):
     def __init__(self, dim, **kwargs):
         super().__init__(**kwargs)
         self.dim = dim
gcvit/layers/feature.py
CHANGED
@@ -1,202 +1,255 @@
(lines 1-159, the Mlp, SE, ReduceSize and FeatExtract layers, are carried over unchanged; the old tail, a single class whose body was cut off in extraction, is replaced by three new layers: GlobalQueryGen, Resizing and FitWindow; the new file reads:)

import tensorflow as tf
import tensorflow_addons as tfa

H_AXIS = -3
W_AXIS = -2

@tf.keras.utils.register_keras_serializable(package="gcvit")
class Mlp(tf.keras.layers.Layer):
    def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
        super().__init__(**kwargs)
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.act_layer = act_layer
        self.dropout = dropout

    def build(self, input_shape):
        self.in_features = input_shape[-1]
        self.hidden_features = self.hidden_features or self.in_features
        self.out_features = self.out_features or self.in_features
        self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
        self.act = tf.keras.layers.Activation(self.act_layer, name="act")
        self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
        self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
        self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "act_layer": self.act_layer,
            "dropout": self.dropout
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class SE(tf.keras.layers.Layer):
    def __init__(self, oup=None, expansion=0.25, **kwargs):
        super().__init__(**kwargs)
        self.expansion = expansion
        self.oup = oup

    def build(self, input_shape):
        inp = input_shape[-1]
        self.oup = self.oup or inp
        self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
        self.fc = [
            tf.keras.layers.Dense(int(inp * self.expansion), use_bias=False, name='fc/0'),
            tf.keras.layers.Activation('gelu', name='fc/1'),
            tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
            tf.keras.layers.Activation('sigmoid', name='fc/3')
        ]
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        b, _, _, c = tf.unstack(tf.shape(inputs), num=4)
        x = tf.reshape(self.avg_pool(inputs), (b, c))
        for layer in self.fc:
            x = layer(x)
        x = tf.reshape(x, (b, 1, 1, c))
        return x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({
            'expansion': self.expansion,
            'oup': self.oup,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class ReduceSize(tf.keras.layers.Layer):
    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        dim_out = dim if self.keep_dim else 2 * dim
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
        ]
        self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
                                                name='reduction')
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')  # eps like PyTorch
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = self.norm1(inputs)
        xr = self.pad1(x)  # if pad had weights it would've thrown error with .save_weights()
        for layer in self.conv:
            xr = layer(xr)
        x = x + xr
        x = self.pad2(x)
        x = self.reduction(x)
        x = self.norm2(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dim": self.keep_dim,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class FeatExtract(tf.keras.layers.Layer):
    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
        ]
        if not self.keep_dim:
            self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
        # else:
        #     self.pool = tf.keras.layers.Activation('linear', name='identity')  # hack for PyTorch nn.Identity layer ;)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        xr = self.pad1(x)
        for layer in self.conv:
            xr = layer(xr)
        x = x + xr  # if pad had weights it would've thrown error with .save_weights()
        if not self.keep_dim:
            x = self.pad2(x)
            x = self.pool(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dim": self.keep_dim,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class GlobalQueryGen(tf.keras.layers.Layer):
    """
    Global query generator based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """
    def __init__(self, keep_dims=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dims = keep_dims

    def build(self, input_shape):
        self.to_q_global = [FeatExtract(keep_dim, name=f'to_q_global/{i}')
                            for i, keep_dim in enumerate(self.keep_dims)]
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        for layer in self.to_q_global:
            x = layer(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dims": self.keep_dims,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class Resizing(tf.keras.layers.Layer):
    def __init__(self,
                 height,
                 width,
                 interpolation='bilinear',
                 **kwargs):
        self.height = height
        self.width = width
        self.interpolation = interpolation
        super().__init__(**kwargs)

    def call(self, inputs):
        # tf.image.resize will always output float32 and operate more efficiently on
        # float32 unless interpolation is nearest, in which case output type matches
        # input type.
        if self.interpolation == 'nearest':
            input_dtype = self.compute_dtype
        else:
            input_dtype = tf.float32
        inputs = tf.cast(inputs, dtype=input_dtype)
        size = [self.height, self.width]
        outputs = tf.image.resize(
            inputs,
            size=size,
            method=self.interpolation)
        return tf.cast(outputs, self.compute_dtype)

    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        input_shape[H_AXIS] = self.height
        input_shape[W_AXIS] = self.width
        return tf.TensorShape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update({
            'height': self.height,
            'width': self.width,
            'interpolation': self.interpolation,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class FitWindow(tf.keras.layers.Layer):
    "Pad feature to fit window"
    def __init__(self, window_size, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size

    def call(self, inputs):
        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
        # pad to multiple of window_size
        h_pad = (self.window_size - H % self.window_size) % self.window_size
        w_pad = (self.window_size - W % self.window_size) % self.window_size
        x = tf.pad(inputs, [[0, 0],
                            [h_pad // 2, (h_pad // 2 + h_pad % 2)],  # padding in both directions unlike tfgcvit
                            [w_pad // 2, (w_pad // 2 + w_pad % 2)],
                            [0, 0]])
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
        })
        return config
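The FitWindow pad arithmetic is easy to sanity-check: for window size w, (w - H % w) % w is the smallest non-negative pad that makes H a multiple of w, split evenly top/bottom with the odd pixel going to the bottom. A minimal sketch (the direct module import is an assumption, since FitWindow is not re-exported from gcvit.layers):

import tensorflow as tf
from gcvit.layers.feature import FitWindow  # assumed import path; FitWindow lives in feature.py

fit = FitWindow(window_size=7)
print(fit(tf.zeros([1, 57, 57, 64])).shape)  # (1, 63, 63, 64): pad = (7 - 57 % 7) % 7 = 6, split 3/3
print(fit(tf.zeros([1, 56, 56, 64])).shape)  # (1, 56, 56, 64): already a multiple, pad = 0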
gcvit/layers/level.py
CHANGED
@@ -1,93 +1,85 @@
(the old side of this hunk was mostly cut off in extraction; the new file reads:)

import tensorflow as tf

from .feature import GlobalQueryGen, ReduceSize, Resizing, FitWindow
from .block import GCViTBlock

@tf.keras.utils.register_keras_serializable(package="gcvit")
class GCViTLevel(tf.keras.layers.Layer):
    def __init__(self, depth, num_heads, window_size, keep_dims, downsample=True, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., path_drop=0., layer_scale=None, resize_query=False, **kwargs):
        super().__init__(**kwargs)
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size
        self.keep_dims = keep_dims
        self.downsample = downsample
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop = drop
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.layer_scale = layer_scale
        self.resize_query = resize_query

    def build(self, input_shape):
        path_drop = [self.path_drop] * self.depth if not isinstance(self.path_drop, list) else self.path_drop
        self.blocks = [
            GCViTBlock(window_size=self.window_size,
                       num_heads=self.num_heads,
                       global_query=bool(i % 2),
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias,
                       qk_scale=self.qk_scale,
                       drop=self.drop,
                       attn_drop=self.attn_drop,
                       path_drop=path_drop[i],
                       layer_scale=self.layer_scale,
                       name=f'blocks/{i}')
            for i in range(self.depth)]
        self.down = ReduceSize(keep_dim=False, name='downsample')
        self.q_global_gen = GlobalQueryGen(self.keep_dims, name='q_global_gen')
        self.resize = Resizing(self.window_size, self.window_size, interpolation='bicubic')
        self.fit_window = FitWindow(self.window_size)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        H, W = tf.unstack(tf.shape(inputs)[1:3], num=2)
        # pad to fit window_size
        x = self.fit_window(inputs)
        # generate global query
        q_global = self.q_global_gen(x)  # (B, H, W, C) # official impl issue: https://github.com/NVlabs/GCVit/issues/13
        # resize query to fit key-value, but results in a poor score with official weights?
        if self.resize_query:
            q_global = self.resize(q_global)  # to avoid mismatch between feat_map and q_global: https://github.com/NVlabs/GCVit/issues/9
        # feature_map -> windows -> window_attention -> feature_map
        for i, blk in enumerate(self.blocks):
            if i % 2:
                x = blk([x, q_global])
            else:
                x = blk([x])
        x = x[:, :H, :W, :]  # https://github.com/NVlabs/GCVit/issues/9
        # set shape for [B, ?, ?, C]
        x.set_shape(inputs.shape)  # `tf.reshape` creates new tensor with new_shape
        # downsample
        if self.downsample:
            x = self.down(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'depth': self.depth,
            'num_heads': self.num_heads,
            'window_size': self.window_size,
            'keep_dims': self.keep_dims,
            'downsample': self.downsample,
            'mlp_ratio': self.mlp_ratio,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'drop': self.drop,
            'attn_drop': self.attn_drop,
            'path_drop': self.path_drop,
            'layer_scale': self.layer_scale
        })
        return config
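Note the alternation driven by global_query=bool(i % 2): even-indexed blocks run plain local window attention, odd-indexed blocks additionally consume q_global. A quick illustrative sketch of the schedule:

depth = 6  # e.g. stage 3 of the xxtiny config below
schedule = ['local+global' if i % 2 else 'local' for i in range(depth)]
print(schedule)  # ['local', 'local+global', 'local', 'local+global', 'local', 'local+global']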
gcvit/models/__init__.py
CHANGED
@@ -1 +1 @@
-from .gcvit import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .gcvit import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase
gcvit/models/__pycache__/__init__.cpython-38.pyc
DELETED (binary file, 234 Bytes)

gcvit/models/__pycache__/gcvit.cpython-38.pyc
DELETED (binary file, 4.08 kB)
gcvit/models/gcvit.py
CHANGED
@@ -1,145 +1,180 @@
(the old side survives only in fragments: the same BASE_URL, a truncated TAG 'v1.0.', and parts of a smaller NAME2CONFIG; the new file reads:)

import numpy as np
import tensorflow as tf

from ..layers import Stem, GCViTLevel, Identity


BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
TAG = 'v1.0.9'
NAME2CONFIG = {
    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
                     'dim': 64,
                     'depths': (2, 2, 6, 2),
                     'num_heads': (2, 4, 8, 16),
                     'mlp_ratio': 3.,
                     'path_drop': 0.2},
    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
                    'dim': 64,
                    'depths': (3, 4, 6, 5),
                    'num_heads': (2, 4, 8, 16),
                    'mlp_ratio': 3.,
                    'path_drop': 0.2},
    'gcvit_tiny': {'window_size': (7, 7, 14, 7),
                   'dim': 64,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (2, 4, 8, 16),
                   'mlp_ratio': 3.,
                   'path_drop': 0.2,},
    'gcvit_small': {'window_size': (7, 7, 14, 7),
                    'dim': 96,
                    'depths': (3, 4, 19, 5),
                    'num_heads': (3, 6, 12, 24),
                    'mlp_ratio': 2.,
                    'path_drop': 0.3,
                    'layer_scale': 1e-5,},
    'gcvit_base': {'window_size': (7, 7, 14, 7),
                   'dim': 128,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (4, 8, 16, 32),
                   'mlp_ratio': 2.,
                   'path_drop': 0.5,
                   'layer_scale': 1e-5,},
}

@tf.keras.utils.register_keras_serializable(package='gcvit')
class GCViT(tf.keras.Model):
    def __init__(self, window_size, dim, depths, num_heads,
                 drop_rate=0., mlp_ratio=3., qkv_bias=True, qk_scale=None, attn_drop=0., path_drop=0.1, layer_scale=None, resize_query=False,
                 global_pool='avg', num_classes=1000, head_act='softmax', **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.dim = dim
        self.depths = depths
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.layer_scale = layer_scale
        self.resize_query = resize_query
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.head_act = head_act

        self.patch_embed = Stem(dim=dim, name='patch_embed')
        self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
        path_drops = np.linspace(0., path_drop, sum(depths))
        keep_dims = [(False, False, False), (False, False), (True,), (True,),]
        self.levels = []
        for i in range(len(depths)):
            path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
                               downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
                               name=f'levels/{i}')
            self.levels.append(level)
        self.norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm')
        if global_pool == 'avg':
            self.pool = tf.keras.layers.GlobalAveragePooling2D(name='pool')
        elif global_pool == 'max':
            self.pool = tf.keras.layers.GlobalMaxPooling2D(name='pool')
        elif global_pool is None:
            self.pool = Identity(name='pool')
        else:
            raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)

    def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
        super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel

    def forward_features(self, inputs):
        x = self.patch_embed(inputs)
        x = self.pos_drop(x)
        x = tf.cast(x, dtype=tf.float32)
        for level in self.levels:
            x = level(x)
        x = self.norm(x)
        return x

    def forward_head(self, inputs, pre_logits=False):
        x = inputs
        if self.global_pool in ['avg', 'max']:
            x = self.pool(x)
        if not pre_logits:
            x = self.head(x)
        return x

    def call(self, inputs, **kwargs):
        x = self.forward_features(inputs)
        x = self.forward_head(x)
        return x

    def build_graph(self, input_shape=(224, 224, 3)):
        """https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam"""
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)

    def summary(self, input_shape=(224, 224, 3)):
        return self.build_graph(input_shape).summary()

# load standard models
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xxtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_tiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_small'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_base'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model
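A minimal usage sketch for the loaders (with pretrain=False nothing is downloaded; pretrain=True assumes the v1.0.9 release assets are reachable):

import tensorflow as tf
from gcvit import GCViTXXTiny

model = GCViTXXTiny(pretrain=False)             # builds and traces on a random (224, 224, 3) input
logits = model(tf.zeros([1, 224, 224, 3]))
print(logits.shape)                             # (1, 1000): softmax over the ImageNet classes
model.reset_classifier(num_classes=10, head_act='softmax')  # swap in a fresh 10-class head
print(model(tf.zeros([1, 224, 224, 3])).shape)  # (1, 10)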
gcvit/utils/gradcam.py
CHANGED
@@ -1,69 +1,69 @@
(lines 1-68 removed and re-added with identical visible content; line 69 is unchanged context)

import tensorflow as tf
import matplotlib.cm as cm
import numpy as np
try:
    from tensorflow.keras.utils import array_to_img, img_to_array
except:
    from tensorflow.keras.preprocessing.image import array_to_img, img_to_array

def process_image(img, size=(224, 224)):
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    img_array = tf.image.resize(img_array, size,)[None,]
    return img_array

def get_gradcam_model(model):
    inp = tf.keras.Input(shape=(224, 224, 3))
    feats = model.forward_features(inp)
    preds = model.forward_head(feats)
    return tf.keras.models.Model(inp, [preds, feats])

def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.4):
    """Grad-CAM for a single image

    Args:
        img (np.ndarray): processed or raw image without batch_shape e.g. (224, 224, 3)
        grad_model (tf.keras.Model): model with feature map and prediction
        process (bool, optional): imagenet pre-processing. Defaults to True.
        pred_index (int, optional): for a particular class. Defaults to None.
        cmap (str, optional): colormap. Defaults to 'jet'.
        alpha (float, optional): opacity. Defaults to 0.4.

    Returns:
        preds_decode: top5 predictions
        heatmap: gradcam heatmap
    """
    # process image for inference
    if process:
        img_array = process_image(img)
    else:
        img_array = tf.convert_to_tensor(img)[None,]
    if img.min() != img.max():
        img = (img - img.min()) / (img.max() - img.min())
        img = np.uint8(img * 255.0)
    # get prediction
    with tf.GradientTape(persistent=True) as tape:
        preds, feats = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]
    # compute heatmap
    grads = tape.gradient(class_channel, feats)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    feats = feats[0]
    heatmap = feats @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    heatmap = heatmap.numpy()
    heatmap = np.uint8(255 * heatmap)
    # colorize heatmap
    cmap = cm.get_cmap(cmap)
    colors = cmap(np.arange(256))[:, :3]
    heatmap = colors[heatmap]
    heatmap = array_to_img(heatmap)
    heatmap = heatmap.resize((img.shape[1], img.shape[0]))
    heatmap = img_to_array(heatmap)
    overlay = img + heatmap * alpha
    overlay = array_to_img(overlay)
    # decode prediction
    preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
    return preds_decode, overlay
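Putting the helpers together on the example image added in this commit; a sketch (assumes the pretrained weights download succeeds):

import numpy as np
from PIL import Image
from gcvit import GCViTTiny
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

model = GCViTTiny(pretrain=True)
grad_model = get_gradcam_model(model)    # maps an image to (predictions, feature map)
img = np.array(Image.open("example/Standing_jaguar.jpg").convert("RGB"))
preds, overlay = get_gradcam_prediction(img, grad_model, cmap='jet', alpha=0.4)
print(preds[:3])                         # top (class_id, label, probability) triples
overlay.save("jaguar_gradcam.png")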
gcvit/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "1.0.
+__version__ = "1.0.9"
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
(lines 1-4 removed and re-added with identical visible content; line 5 is unchanged context)

tensorflow==2.4.1
tensorflow_addons==0.14.0
gradio==3.1.0
numpy
matplotlib
setup.py
CHANGED
@@ -1,50 +1,50 @@
(lines 1-49 removed and re-added with identical visible content; line 50 is unchanged context)

from setuptools import setup, find_packages
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

with open(path.join(here, 'requirements.txt')) as f:
    install_requires = [x for x in f.read().splitlines() if len(x)]

exec(open("gcvit/version.py").read())

setup(
    name="gcvit",
    version=__version__,
    description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/awsaf49/gcvit-tf",
    author="Awsaf",
    author_email="[email protected]",
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    # Note that this is a string of words separated by whitespace, not a list.
    keywords="tensorflow computer_vision image classification transformer",
    packages=find_packages(exclude=["tests"]),
    include_package_data=True,
    install_requires=install_requires,
    python_requires=">=3.6",
    license="MIT",
)