awsaf49 committed
Commit 4a0cabe
1 parent: 69a7cee

latest version

.gitattributes CHANGED
@@ -1,31 +1,31 @@
All 31 rules were removed and re-added with identical text (a whitespace-only rewrite):
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
All 13 lines were removed and re-added with identical text (a whitespace-only rewrite):
---
title: Gcvit Tf
emoji: 📈
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 3.1.0
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,33 +1,33 @@
The file was re-written line-for-line with identical text (a whitespace-only change):
import tensorflow as tf
import gradio as gr
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

def predict_fn(image, model_name):
    """A predict function that will be invoked by gradio."""
    model = getattr(gcvit, model_name)(pretrain=True)
    gradcam_model = get_gradcam_model(model)
    preds, overlay = get_gradcam_prediction(image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None)
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]

demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.inputs.Image(label="Input Image"),
        gr.Radio(['GCViTTiny', 'GCViTSmall', 'GCViTBase'], value='GCViTTiny', label='Model Name')
    ],
    outputs=[
        gr.outputs.Label(label="Prediction"),
        gr.inputs.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with GCViT Model using ImageNet Pretrain Weights.",
    examples=[
        ["example/hot_air_ballon.jpg", 'GCViTTiny'],
        ["example/chelsea.png", 'GCViTTiny'],
        ["example/penguin.JPG", 'GCViTTiny'],
        ["example/bus.jpg", 'GCViTTiny'],
    ],
)
demo.launch()
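For reference, a minimal sketch of what predict_fn does per request, outside the Gradio UI, assuming the pinned requirements are installed and the bundled example images are present:

```python
import numpy as np
from PIL import Image
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

# Same steps predict_fn runs for one request; pretrained weights download on first use.
model = gcvit.GCViTTiny(pretrain=True)          # resolved via getattr(gcvit, model_name) in app.py
gradcam_model = get_gradcam_model(model)
image = np.array(Image.open("example/chelsea.png").convert("RGB"))
preds, overlay = get_gradcam_prediction(image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None)
labels = {x[1]: float(x[2]) for x in preds}     # {label: probability}, the format gr.outputs.Label expects
print(labels)
overlay.save("gradcam_overlay.png")             # input image with the Grad-CAM heatmap blended in
```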
example/Standing_jaguar.jpg ADDED
gcvit/__init__.py CHANGED
@@ -1,2 +1,2 @@
-from .models import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .models import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase
 from .version import __version__
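The two new variants exported here are built the same way as the existing ones; a quick sketch (no weights are downloaded unless pretrain=True):

```python
import gcvit

# New extra-small variants added in this commit, next to the existing ones.
model_xx = gcvit.GCViTXXTiny(pretrain=False)   # depths (2, 2, 6, 2)
model_x = gcvit.GCViTXTiny(pretrain=False)     # depths (3, 4, 6, 5)
print(gcvit.__version__)                        # "1.0.9" after this commit
```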
gcvit/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (228 Bytes)
 
gcvit/layers/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from .window import window_partition, window_reverse
 from .attention import WindowAttention
 from .drop import DropPath, Identity
-from .embedding import PatchEmbed
+from .embedding import Stem
 from .feature import Mlp, FeatExtract, ReduceSize, SE, Resizing
 from .block import GCViTBlock
-from .level import GCViTLayer
+from .level import GCViTLevel
gcvit/layers/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (530 Bytes)
 
gcvit/layers/__pycache__/attention.cpython-38.pyc DELETED
Binary file (3.58 kB)
 
gcvit/layers/__pycache__/block.cpython-38.pyc DELETED
Binary file (3 kB)
 
gcvit/layers/__pycache__/drop.cpython-38.pyc DELETED
Binary file (1.8 kB)
 
gcvit/layers/__pycache__/embedding.cpython-38.pyc DELETED
Binary file (1.39 kB)
 
gcvit/layers/__pycache__/feature.cpython-38.pyc DELETED
Binary file (5.5 kB)
 
gcvit/layers/__pycache__/level.cpython-38.pyc DELETED
Binary file (3 kB)
 
gcvit/layers/__pycache__/window.cpython-38.pyc DELETED
Binary file (801 Bytes)
 
gcvit/layers/block.py CHANGED
@@ -1,99 +1,99 @@
The file was re-written line-for-line with identical text (a whitespace-only change):
import tensorflow as tf

from .attention import WindowAttention
from .drop import DropPath
from .window import window_partition, window_reverse
from .feature import Mlp, FeatExtract


@tf.keras.utils.register_keras_serializable(package="gcvit")
class GCViTBlock(tf.keras.layers.Layer):
    def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
                 attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.num_heads = num_heads
        self.global_query = global_query
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop = drop
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.act_layer = act_layer
        self.layer_scale = layer_scale

    def build(self, input_shape):
        B, H, W, C = input_shape[0]
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
        self.attn = WindowAttention(window_size=self.window_size,
                                    num_heads=self.num_heads,
                                    global_query=self.global_query,
                                    qkv_bias=self.qkv_bias,
                                    qk_scale=self.qk_scale,
                                    attn_dropout=self.attn_drop,
                                    proj_dropout=self.drop,
                                    name='attn')
        self.drop_path1 = DropPath(self.path_drop)
        self.drop_path2 = DropPath(self.path_drop)
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
        if self.layer_scale is not None:
            self.gamma1 = self.add_weight(
                'gamma1',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
            self.gamma2 = self.add_weight(
                'gamma2',
                shape=[C],
                initializer=tf.keras.initializers.Constant(self.layer_scale),
                trainable=True,
                dtype=self.dtype)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.num_windows = int(H // self.window_size) * int(W // self.window_size)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        if self.global_query:
            inputs, q_global = inputs
        else:
            inputs = inputs[0]
        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
        x = self.norm1(inputs)
        # create windows and concat them in batch axis
        x = window_partition(x, self.window_size)  # (B_, win_h, win_w, C)
        # flatten patch
        x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C])  # (B_, N, C) => (batch*num_win, num_token, feature)
        # attention
        if self.global_query:
            x = self.attn([x, q_global])
        else:
            x = self.attn([x])
        # reverse window partition
        x = window_reverse(x, self.window_size, H, W, C)
        # FFN
        x = inputs + self.drop_path1(x * self.gamma1)
        x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
            'num_heads': self.num_heads,
            'global_query': self.global_query,
            'mlp_ratio': self.mlp_ratio,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'drop': self.drop,
            'attn_drop': self.attn_drop,
            'path_drop': self.path_drop,
            'act_layer': self.act_layer,
            'layer_scale': self.layer_scale,
            'num_windows': self.num_windows,
        })
        return config
@@ -4,7 +4,7 @@ from .feature import ReduceSize
4
 
5
 
6
  @tf.keras.utils.register_keras_serializable(package="gcvit")
7
- class PatchEmbed(tf.keras.layers.Layer):
8
  def __init__(self, dim, **kwargs):
9
  super().__init__(**kwargs)
10
  self.dim = dim
 
4
 
5
 
6
  @tf.keras.utils.register_keras_serializable(package="gcvit")
7
+ class Stem(tf.keras.layers.Layer):
8
  def __init__(self, dim, **kwargs):
9
  super().__init__(**kwargs)
10
  self.dim = dim
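Only the class name changes here; a one-line migration sketch:

```python
# After this commit the stem is imported as Stem (formerly PatchEmbed):
from gcvit.layers import Stem

stem = Stem(dim=64, name='patch_embed')   # used as the patch_embed layer in gcvit/models/gcvit.py
```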
gcvit/layers/feature.py CHANGED
@@ -1,202 +1,255 @@
The file is re-written wholesale, but the existing layers (Mlp, SE, ReduceSize, FeatExtract, Resizing) are unchanged; the substantive change is two new layers: GlobalQueryGen, which stacks FeatExtract stages to build the global query, and FitWindow, which pads features to a multiple of the window size. New file:
import tensorflow as tf
import tensorflow_addons as tfa

H_AXIS = -3
W_AXIS = -2

@tf.keras.utils.register_keras_serializable(package="gcvit")
class Mlp(tf.keras.layers.Layer):
    def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
        super().__init__(**kwargs)
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.act_layer = act_layer
        self.dropout = dropout

    def build(self, input_shape):
        self.in_features = input_shape[-1]
        self.hidden_features = self.hidden_features or self.in_features
        self.out_features = self.out_features or self.in_features
        self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
        self.act = tf.keras.layers.Activation(self.act_layer, name="act")
        self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
        self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
        self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "act_layer": self.act_layer,
            "dropout": self.dropout
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class SE(tf.keras.layers.Layer):
    def __init__(self, oup=None, expansion=0.25, **kwargs):
        super().__init__(**kwargs)
        self.expansion = expansion
        self.oup = oup

    def build(self, input_shape):
        inp = input_shape[-1]
        self.oup = self.oup or inp
        self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
        self.fc = [
            tf.keras.layers.Dense(int(inp * self.expansion), use_bias=False, name='fc/0'),
            tf.keras.layers.Activation('gelu', name='fc/1'),
            tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
            tf.keras.layers.Activation('sigmoid', name='fc/3')
        ]
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        b, _, _, c = tf.unstack(tf.shape(inputs), num=4)
        x = tf.reshape(self.avg_pool(inputs), (b, c))
        for layer in self.fc:
            x = layer(x)
        x = tf.reshape(x, (b, 1, 1, c))
        return x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({
            'expansion': self.expansion,
            'oup': self.oup,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class ReduceSize(tf.keras.layers.Layer):
    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        dim_out = dim if self.keep_dim else 2 * dim
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
        ]
        self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
                                                name='reduction')
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')  # eps like PyTorch
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = self.norm1(inputs)
        xr = self.pad1(x)  # if pad had weights it would've thrown error with .save_weights()
        for layer in self.conv:
            xr = layer(xr)
        x = x + xr
        x = self.pad2(x)
        x = self.reduction(x)
        x = self.norm2(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dim": self.keep_dim,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class FeatExtract(tf.keras.layers.Layer):
    def __init__(self, keep_dim=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dim = keep_dim

    def build(self, input_shape):
        dim = input_shape[-1]
        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
        self.conv = [
            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
            tf.keras.layers.Activation('gelu', name='conv/1'),
            SE(name='conv/2'),
            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
        ]
        if not self.keep_dim:
            self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
        # else:
        #     self.pool = tf.keras.layers.Activation('linear', name='identity')  # hack for PyTorch nn.Identity layer ;)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        xr = self.pad1(x)
        for layer in self.conv:
            xr = layer(xr)
        x = x + xr  # if pad had weights it would've thrown error with .save_weights()
        if not self.keep_dim:
            x = self.pad2(x)
            x = self.pool(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dim": self.keep_dim,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class GlobalQueryGen(tf.keras.layers.Layer):
    """
    Global query generator based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """
    def __init__(self, keep_dims=False, **kwargs):
        super().__init__(**kwargs)
        self.keep_dims = keep_dims

    def build(self, input_shape):
        self.to_q_global = [FeatExtract(keep_dim, name=f'to_q_global/{i}')
                            for i, keep_dim in enumerate(self.keep_dims)]
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        for layer in self.to_q_global:
            x = layer(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "keep_dims": self.keep_dims,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class Resizing(tf.keras.layers.Layer):
    def __init__(self,
                 height,
                 width,
                 interpolation='bilinear',
                 **kwargs):
        self.height = height
        self.width = width
        self.interpolation = interpolation
        super().__init__(**kwargs)

    def call(self, inputs):
        # tf.image.resize will always output float32 and operate more efficiently on
        # float32 unless interpolation is nearest, in which case output type matches
        # input type.
        if self.interpolation == 'nearest':
            input_dtype = self.compute_dtype
        else:
            input_dtype = tf.float32
        inputs = tf.cast(inputs, dtype=input_dtype)
        size = [self.height, self.width]
        outputs = tf.image.resize(
            inputs,
            size=size,
            method=self.interpolation)
        return tf.cast(outputs, self.compute_dtype)

    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        input_shape[H_AXIS] = self.height
        input_shape[W_AXIS] = self.width
        return tf.TensorShape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update({
            'height': self.height,
            'width': self.width,
            'interpolation': self.interpolation,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="gcvit")
class FitWindow(tf.keras.layers.Layer):
    """Pad feature to fit window"""
    def __init__(self, window_size, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size

    def call(self, inputs):
        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
        # pad to multiple of window_size
        h_pad = (self.window_size - H % self.window_size) % self.window_size
        w_pad = (self.window_size - W % self.window_size) % self.window_size
        x = tf.pad(inputs, [[0, 0],
                            [h_pad // 2, (h_pad // 2 + h_pad % 2)],  # padding in both directions unlike tfgcvit
                            [w_pad // 2, (w_pad // 2 + w_pad % 2)],
                            [0, 0]])
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'window_size': self.window_size,
        })
        return config
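The padding arithmetic in the new FitWindow layer is worth a worked example; a small sketch (FitWindow is not re-exported from gcvit.layers, so it is imported from the module directly):

```python
import tensorflow as tf
from gcvit.layers.feature import FitWindow

# FitWindow pads H and W up to the next multiple of window_size,
# splitting the padding near-evenly between the two sides.
x = tf.random.uniform((1, 10, 10, 3))
y = FitWindow(window_size=7)(x)
# h_pad = (7 - 10 % 7) % 7 = 4  ->  2 rows on top, 2 on the bottom (same for width)
print(y.shape)                      # (1, 14, 14, 3)
```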
gcvit/layers/level.py CHANGED
@@ -1,93 +1,85 @@
GCViTLayer is renamed GCViTLevel, and the hand-rolled padding and global-query loops in call() move into the new FitWindow and GlobalQueryGen layers, wired up in build() as self.fit_window and self.q_global_gen. New file:
import tensorflow as tf

from .feature import GlobalQueryGen, ReduceSize, Resizing, FitWindow
from .block import GCViTBlock

@tf.keras.utils.register_keras_serializable(package="gcvit")
class GCViTLevel(tf.keras.layers.Layer):
    def __init__(self, depth, num_heads, window_size, keep_dims, downsample=True, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., path_drop=0., layer_scale=None, resize_query=False, **kwargs):
        super().__init__(**kwargs)
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size
        self.keep_dims = keep_dims
        self.downsample = downsample
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.drop = drop
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.layer_scale = layer_scale
        self.resize_query = resize_query

    def build(self, input_shape):
        path_drop = [self.path_drop] * self.depth if not isinstance(self.path_drop, list) else self.path_drop
        self.blocks = [
            GCViTBlock(window_size=self.window_size,
                       num_heads=self.num_heads,
                       global_query=bool(i % 2),
                       mlp_ratio=self.mlp_ratio,
                       qkv_bias=self.qkv_bias,
                       qk_scale=self.qk_scale,
                       drop=self.drop,
                       attn_drop=self.attn_drop,
                       path_drop=path_drop[i],
                       layer_scale=self.layer_scale,
                       name=f'blocks/{i}')
            for i in range(self.depth)]
        self.down = ReduceSize(keep_dim=False, name='downsample')
        self.q_global_gen = GlobalQueryGen(self.keep_dims, name='q_global_gen')
        self.resize = Resizing(self.window_size, self.window_size, interpolation='bicubic')
        self.fit_window = FitWindow(self.window_size)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        H, W = tf.unstack(tf.shape(inputs)[1:3], num=2)
        # pad to fit window_size
        x = self.fit_window(inputs)
        # generate global query
        q_global = self.q_global_gen(x)  # (B, H, W, C)  # official impl issue: https://github.com/NVlabs/GCVit/issues/13
        # resize query to fit key-value, but this results in a poor score with official weights?
        if self.resize_query:
            q_global = self.resize(q_global)  # to avoid mismatch between feat_map and q_global: https://github.com/NVlabs/GCVit/issues/9
        # feature_map -> windows -> window_attention -> feature_map
        for i, blk in enumerate(self.blocks):
            if i % 2:
                x = blk([x, q_global])
            else:
                x = blk([x])
        x = x[:, :H, :W, :]  # https://github.com/NVlabs/GCVit/issues/9
        # set shape for [B, ?, ?, C]
        x.set_shape(inputs.shape)  # `tf.reshape` creates new tensor with new_shape
        # downsample
        if self.downsample:
            x = self.down(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'depth': self.depth,
            'num_heads': self.num_heads,
            'window_size': self.window_size,
            'keep_dims': self.keep_dims,
            'downsample': self.downsample,
            'mlp_ratio': self.mlp_ratio,
            'qkv_bias': self.qkv_bias,
            'qk_scale': self.qk_scale,
            'drop': self.drop,
            'attn_drop': self.attn_drop,
            'path_drop': self.path_drop,
            'layer_scale': self.layer_scale
        })
        return config
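A level now reads as pad -> global query -> alternating local/global blocks -> crop -> downsample; a standalone sketch with stage-0 sizes (56x56 feature map, dim 64, three query-downsampling steps so q_global matches the 7x7 window):

```python
import tensorflow as tf
from gcvit.layers import GCViTLevel

level = GCViTLevel(depth=2, num_heads=2, window_size=7,
                   keep_dims=(False, False, False), downsample=True)
x = tf.random.uniform((1, 56, 56, 64))
y = level(x)
print(y.shape)                      # (1, 28, 28, 128): ReduceSize halves H, W and doubles C
```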
gcvit/models/__init__.py CHANGED
@@ -1 +1 @@
-from .gcvit import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .gcvit import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase
gcvit/models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (234 Bytes)
 
gcvit/models/__pycache__/gcvit.cpython-38.pyc DELETED
Binary file (4.08 kB)
 
gcvit/models/gcvit.py CHANGED
@@ -1,145 +1,180 @@
Changes from the previous version: TAG bumps from v1.0.4 to v1.0.9; two new configs, gcvit_xxtiny (depths 2, 2, 6, 2) and gcvit_xtiny (depths 3, 4, 6, 5), join the table, and gcvit_tiny now spells out mlp_ratio=3.; imports follow the PatchEmbed -> Stem and GCViTLayer -> GCViTLevel renames; the two-layer head (Dense 'head/fc' plus Activation 'head/act') collapses into a single Dense('head') with a built-in activation, with reset_classifier updated to match and given an in_channels argument; a summary() helper is added; and every factory function gains input_shape and resize_query arguments, with new GCViTXXTiny and GCViTXTiny factories. New file:
import numpy as np
import tensorflow as tf

from ..layers import Stem, GCViTLevel, Identity


BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
TAG = 'v1.0.9'
NAME2CONFIG = {
    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
                     'dim': 64,
                     'depths': (2, 2, 6, 2),
                     'num_heads': (2, 4, 8, 16),
                     'mlp_ratio': 3.,
                     'path_drop': 0.2},
    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
                    'dim': 64,
                    'depths': (3, 4, 6, 5),
                    'num_heads': (2, 4, 8, 16),
                    'mlp_ratio': 3.,
                    'path_drop': 0.2},
    'gcvit_tiny': {'window_size': (7, 7, 14, 7),
                   'dim': 64,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (2, 4, 8, 16),
                   'mlp_ratio': 3.,
                   'path_drop': 0.2,},
    'gcvit_small': {'window_size': (7, 7, 14, 7),
                    'dim': 96,
                    'depths': (3, 4, 19, 5),
                    'num_heads': (3, 6, 12, 24),
                    'mlp_ratio': 2.,
                    'path_drop': 0.3,
                    'layer_scale': 1e-5,},
    'gcvit_base': {'window_size': (7, 7, 14, 7),
                   'dim': 128,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (4, 8, 16, 32),
                   'mlp_ratio': 2.,
                   'path_drop': 0.5,
                   'layer_scale': 1e-5,},
}

@tf.keras.utils.register_keras_serializable(package='gcvit')
class GCViT(tf.keras.Model):
    def __init__(self, window_size, dim, depths, num_heads,
                 drop_rate=0., mlp_ratio=3., qkv_bias=True, qk_scale=None, attn_drop=0., path_drop=0.1, layer_scale=None, resize_query=False,
                 global_pool='avg', num_classes=1000, head_act='softmax', **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.dim = dim
        self.depths = depths
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.layer_scale = layer_scale
        self.resize_query = resize_query
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.head_act = head_act

        self.patch_embed = Stem(dim=dim, name='patch_embed')
        self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
        path_drops = np.linspace(0., path_drop, sum(depths))
        keep_dims = [(False, False, False), (False, False), (True,), (True,),]
        self.levels = []
        for i in range(len(depths)):
            path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
                               downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
                               name=f'levels/{i}')
            self.levels.append(level)
        self.norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm')
        if global_pool == 'avg':
            self.pool = tf.keras.layers.GlobalAveragePooling2D(name='pool')
        elif global_pool == 'max':
            self.pool = tf.keras.layers.GlobalMaxPooling2D(name='pool')
        elif global_pool is None:
            self.pool = Identity(name='pool')
        else:
            raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)

    def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
        super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel

    def forward_features(self, inputs):
        x = self.patch_embed(inputs)
        x = self.pos_drop(x)
        x = tf.cast(x, dtype=tf.float32)
        for level in self.levels:
            x = level(x)
        x = self.norm(x)
        return x

    def forward_head(self, inputs, pre_logits=False):
        x = inputs
        if self.global_pool in ['avg', 'max']:
            x = self.pool(x)
        if not pre_logits:
            x = self.head(x)
        return x

    def call(self, inputs, **kwargs):
        x = self.forward_features(inputs)
        x = self.forward_head(x)
        return x

    def build_graph(self, input_shape=(224, 224, 3)):
        """https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam"""
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)

    def summary(self, input_shape=(224, 224, 3)):
        return self.build_graph(input_shape).summary()

# load standard models
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xxtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_tiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_small'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_base'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model
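A short sketch of the refactored head and the new helpers, assuming no pretrained weights are needed:

```python
import tensorflow as tf
from gcvit import GCViTXXTiny

model = GCViTXXTiny(pretrain=False)    # builds itself on a random (1, 224, 224, 3) forward pass
model.summary()                        # new helper: wraps build_graph(...).summary()

# Swap the 1000-class ImageNet head for a 10-class one (reset_classifier now takes in_channels):
model.reset_classifier(num_classes=10, head_act='softmax', in_channels=3)
probs = model(tf.random.uniform((1, 224, 224, 3)))
print(probs.shape)                     # (1, 10)
```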
gcvit/utils/gradcam.py CHANGED
@@ -1,69 +1,69 @@
The file was re-written line-for-line with identical text (a whitespace-only change):
import tensorflow as tf
import matplotlib.cm as cm
import numpy as np
try:
    from tensorflow.keras.utils import array_to_img, img_to_array
except:
    from tensorflow.keras.preprocessing.image import array_to_img, img_to_array

def process_image(img, size=(224, 224)):
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    img_array = tf.image.resize(img_array, size,)[None,]
    return img_array

def get_gradcam_model(model):
    inp = tf.keras.Input(shape=(224, 224, 3))
    feats = model.forward_features(inp)
    preds = model.forward_head(feats)
    return tf.keras.models.Model(inp, [preds, feats])

def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.4):
    """Grad-CAM for a single image

    Args:
        img (np.ndarray): processed or raw image without batch shape, e.g. (224, 224, 3)
        grad_model (tf.keras.Model): model with feature map and prediction
        process (bool, optional): imagenet pre-processing. Defaults to True.
        pred_index (int, optional): for a particular class. Defaults to None.
        cmap (str, optional): colormap. Defaults to 'jet'.
        alpha (float, optional): opacity. Defaults to 0.4.

    Returns:
        preds_decode: top5 predictions
        heatmap: gradcam heatmap
    """
    # process image for inference
    if process:
        img_array = process_image(img)
    else:
        img_array = tf.convert_to_tensor(img)[None,]
    if img.min() != img.max():
        img = (img - img.min()) / (img.max() - img.min())
        img = np.uint8(img * 255.0)
    # get prediction
    with tf.GradientTape(persistent=True) as tape:
        preds, feats = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]
    # compute heatmap
    grads = tape.gradient(class_channel, feats)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    feats = feats[0]
    heatmap = feats @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    heatmap = heatmap.numpy()
    heatmap = np.uint8(255 * heatmap)
    # colorize heatmap
    cmap = cm.get_cmap(cmap)
    colors = cmap(np.arange(256))[:, :3]
    heatmap = colors[heatmap]
    heatmap = array_to_img(heatmap)
    heatmap = heatmap.resize((img.shape[1], img.shape[0]))
    heatmap = img_to_array(heatmap)
    overlay = img + heatmap * alpha
    overlay = array_to_img(overlay)
    # decode prediction
    preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
    return preds_decode, overlay
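A sketch of the two knobs the demo leaves at their defaults, using the example/Standing_jaguar.jpg image added in this commit:

```python
import numpy as np
from PIL import Image
from gcvit import GCViTTiny
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

grad_model = get_gradcam_model(GCViTTiny(pretrain=True))
img = np.array(Image.open("example/Standing_jaguar.jpg").convert("RGB"), dtype="float32")

# pred_index=None explains the top-1 class; pass an int to explain a specific class channel.
# decode=False returns the raw (1, 1000) probability tensor instead of decoded top-5 tuples.
preds, overlay = get_gradcam_prediction(img, grad_model, pred_index=None, decode=False)
print(preds.numpy().argmax())       # index of the predicted ImageNet class
overlay.save("jaguar_gradcam.png")  # input image with the Grad-CAM heatmap blended in
```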
gcvit/version.py CHANGED
@@ -1 +1 @@
-__version__ = "1.0.3"
+__version__ = "1.0.9"
requirements.txt CHANGED
@@ -1,5 +1,5 @@
The file was re-written line-for-line with identical text (a whitespace-only change):
tensorflow==2.4.1
tensorflow_addons==0.14.0
gradio==3.1.0
numpy
matplotlib
setup.py CHANGED
@@ -1,50 +1,50 @@
The file was re-written line-for-line with identical text (a whitespace-only change):
from setuptools import setup, find_packages
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

with open(path.join(here, 'requirements.txt')) as f:
    install_requires = [x for x in f.read().splitlines() if len(x)]

exec(open("gcvit/version.py").read())

setup(
    name="gcvit",
    version=__version__,
    description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/awsaf49/gcvit-tf",
    author="Awsaf",
    author_email="[email protected]",
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    # Note that this is a string of words separated by whitespace, not a list.
    keywords="tensorflow computer_vision image classification transformer",
    packages=find_packages(exclude=["tests"]),
    include_package_data=True,
    install_requires=install_requires,
    python_requires=">=3.6",
    license="MIT",
)