ahatamiz committed on
Commit eb45d3e · verified · 1 Parent(s): b79e36c

Upload model
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
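
The "How to Get Started with the Model" section above is still a placeholder. Because config.json (below) maps `AutoConfig`/`AutoModel` onto the bundled `configuration_mambavision.py` and `modeling_mambavision.py`, a minimal loading sketch could look like the following; the `<repo-id>` placeholder and the 512×512 input size are assumptions, and `mamba_ssm`, `timm`, `einops`, and a CUDA GPU are required because the Mamba mixer uses `selective_scan_fn`.

```python
# Minimal sketch, not part of the uploaded files. Assumes "<repo-id>" is replaced
# with this repository's id and that mamba_ssm/timm/einops are installed.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True).cuda().eval()

# MambaVisionModel.forward delegates to MambaVision.forward_features, which
# returns (pooled features, list of per-stage feature maps).
image = torch.randn(1, 3, 512, 512, device="cuda")  # input size is an assumption
with torch.no_grad():
    pooled, stage_features = model(image)

print(pooled.shape)                        # (1, 2048)
print([f.shape for f in stage_features])   # one feature map per stage
```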
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "architectures": [
+     "MambaVisionModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_mambavision.MambaVisionConfig",
+     "AutoModel": "modeling_mambavision.MambaVisionModel"
+   },
+   "depths": [
+     3,
+     3,
+     20,
+     10
+   ],
+   "dim": 256,
+   "drop_path_rate": 0.3,
+   "in_dim": 64,
+   "layer_scale": 1e-05,
+   "layer_scale_conv": null,
+   "mlp_ratio": 4,
+   "model_type": "mambavision",
+   "num_heads": [
+     4,
+     8,
+     16,
+     32
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.1",
+   "window_size": [
+     8,
+     8,
+     16,
+     8
+   ]
+ }
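
Read together with modeling_mambavision.py below, this config describes a four-stage hierarchical backbone: the first two stages are convolutional, the last two mix Mamba and attention blocks, and the channel width doubles at each stage starting from `dim`. A small sketch (plain arithmetic on the values above, nothing model-specific) makes the stage layout explicit:

```python
# Sketch: derive the per-stage layout implied by config.json; the "conv" /
# "mamba+attention" split mirrors MambaVision.__init__ (conv=True for stages 0-1).
depths = [3, 3, 20, 10]
num_heads = [4, 8, 16, 32]
window_size = [8, 8, 16, 8]
dim = 256

for i, (d, h, w) in enumerate(zip(depths, num_heads, window_size)):
    width = dim * 2 ** i                      # int(dim * 2 ** i) in the model code
    kind = "conv" if i < 2 else "mamba+attention"
    print(f"stage {i}: {d} blocks, width {width}, {h} heads, window {w}, {kind}")

print("final feature width:", dim * 2 ** (len(depths) - 1))  # 2048
```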
configuration_mambavision.py ADDED
@@ -0,0 +1,28 @@
+ from transformers import PretrainedConfig
+
+ class MambaVisionConfig(PretrainedConfig):
+     model_type = "mambavision"
+
+     def __init__(
+         self,
+         depths=[3, 3, 20, 10],
+         num_heads=[4, 8, 16, 32],
+         window_size=[8, 8, 16, 8],
+         dim=256,
+         in_dim=64,
+         mlp_ratio=4,
+         drop_path_rate=0.3,
+         layer_scale=1e-5,
+         layer_scale_conv=None,
+         **kwargs,
+     ):
+         self.depths = depths
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.dim = dim
+         self.in_dim = in_dim
+         self.mlp_ratio = mlp_ratio
+         self.drop_path_rate = drop_path_rate
+         self.layer_scale = layer_scale
+         self.layer_scale_conv = layer_scale_conv
+         super().__init__(**kwargs)
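
Since `MambaVisionConfig`'s defaults mirror config.json, the class round-trips cleanly through the standard `PretrainedConfig` serialization; a quick sanity-check sketch, assuming the file is importable from the working directory:

```python
from configuration_mambavision import MambaVisionConfig

cfg = MambaVisionConfig()                       # defaults match config.json
print(cfg.depths, cfg.dim, cfg.window_size)     # [3, 3, 20, 10] 256 [8, 8, 16, 8]

cfg.save_pretrained("mambavision-config")       # writes config.json to that folder
reloaded = MambaVisionConfig.from_pretrained("mambavision-config")
assert reloaded.to_dict()["depths"] == cfg.depths
```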
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22bf14ac8e50e0551facd8279f9225ea08d1f9d0bea14aa0c610b090f0883e2f
+ size 2958403688
modeling_mambavision.py ADDED
@@ -0,0 +1,786 @@
+ #!/usr/bin/env python3
+
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+ import torch
+ import torch.nn as nn
+ from timm.models.registry import register_model
+ import math
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
+ from timm.models._builder import resolve_pretrained_cfg
+ try:
+     from timm.models._builder import _update_default_kwargs as update_args
+ except ImportError:
+     from timm.models._builder import _update_default_model_kwargs as update_args
+ from timm.models.vision_transformer import Mlp, PatchEmbed
+ from timm.models.layers import DropPath, trunc_normal_
+ from timm.models.registry import register_model
+ import torch.nn.functional as F
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+ from einops import rearrange, repeat
+
+ from transformers import PreTrainedModel
+ try:
+     from .configuration_mambavision import MambaVisionConfig
+ except ImportError:
+     from configuration_mambavision import MambaVisionConfig
+
+
+ def _cfg(url='', **kwargs):
+     return {'url': url,
+             'num_classes': 1000,
+             'input_size': (3, 224, 224),
+             'pool_size': None,
+             'crop_pct': 0.875,
+             'interpolation': 'bicubic',
+             'fixed_input_size': True,
+             'mean': (0.485, 0.456, 0.406),
+             'std': (0.229, 0.224, 0.225),
+             **kwargs
+             }
+
+
+ default_cfgs = {
+     'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
+                            crop_pct=1.0,
+                            input_size=(3, 224, 224),
+                            crop_mode='center'),
+     'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
+                             crop_pct=0.98,
+                             input_size=(3, 224, 224),
+                             crop_mode='center'),
+     'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
+                            crop_pct=0.93,
+                            input_size=(3, 224, 224),
+                            crop_mode='center'),
+     'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
+                            crop_pct=1.0,
+                            input_size=(3, 224, 224),
+                            crop_mode='center'),
+     'mamba_vision_B_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-21K/resolve/main/mambavision_base_21k.pth.tar',
+                                crop_pct=1.0,
+                                input_size=(3, 224, 224),
+                                crop_mode='center'),
+     'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
+                            crop_pct=1.0,
+                            input_size=(3, 224, 224),
+                            crop_mode='center'),
+     'mamba_vision_L_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-21K/resolve/main/mambavision_large_21k.pth.tar',
+                                crop_pct=1.0,
+                                input_size=(3, 224, 224),
+                                crop_mode='center'),
+     'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
+                             crop_pct=1.0,
+                             input_size=(3, 224, 224),
+                             crop_mode='center'),
+     'mamba_vision_L2_512_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-21K-512/resolve/main/mambavision_L2_21k_240m_512.pth.tar',
+                                     crop_pct=0.93,
+                                     input_size=(3, 512, 512),
+                                     crop_mode='squash'),
+     'mamba_vision_L3_256_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L3-21K-256/resolve/main/mambavision_L3_21k_740m_256.pth.tar',
+                                     crop_pct=1.0,
+                                     input_size=(3, 256, 256),
+                                     crop_mode='center'),
+     'mamba_vision_L3_512_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L3-21K-512/resolve/main/mambavision_L3_21k_740m_512.pth.tar',
+                                     crop_pct=0.93,
+                                     input_size=(3, 512, 512),
+                                     crop_mode='squash'),
+ }
+
+
+ def window_partition(x, window_size):
+     """
+     Args:
+         x: (B, C, H, W)
+         window_size: window size
+         h_w: Height of window
+         w_w: Width of window
+     Returns:
+         local window features (num_windows*B, window_size*window_size, C)
+     """
+     B, C, H, W = x.shape
+     x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
+     windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
+     return windows
+
+
+ def window_reverse(windows, window_size, H, W):
+     """
+     Args:
+         windows: local window features (num_windows*B, window_size*window_size, C)
+         window_size: Window size
+         H: Height of image
+         W: Width of image
+     Returns:
+         x: (B, C, H, W)
+     """
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
+     x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
+     return x
+
+
+ def _load_state_dict(module, state_dict, strict=False, logger=None):
+     """Load state_dict to a module.
+
+     This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+     Default value for ``strict`` is set to ``False`` and the message for
+     param mismatch will be shown even if strict is False.
+
+     Args:
+         module (Module): Module that receives the state_dict.
+         state_dict (OrderedDict): Weights.
+         strict (bool): whether to strictly enforce that the keys
+             in :attr:`state_dict` match the keys returned by this module's
+             :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+         logger (:obj:`logging.Logger`, optional): Logger to log the error
+             message. If not specified, print function will be used.
+     """
+     unexpected_keys = []
+     all_missing_keys = []
+     err_msg = []
+
+     metadata = getattr(state_dict, '_metadata', None)
+     state_dict = state_dict.copy()
+     if metadata is not None:
+         state_dict._metadata = metadata
+
+     def load(module, prefix=''):
+         local_metadata = {} if metadata is None else metadata.get(
+             prefix[:-1], {})
+         module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                      all_missing_keys, unexpected_keys,
+                                      err_msg)
+         for name, child in module._modules.items():
+             if child is not None:
+                 load(child, prefix + name + '.')
+
+     load(module)
+     load = None
+     missing_keys = [
+         key for key in all_missing_keys if 'num_batches_tracked' not in key
+     ]
+
+     if unexpected_keys:
+         err_msg.append('unexpected key in source '
+                        f'state_dict: {", ".join(unexpected_keys)}\n')
+     if missing_keys:
+         err_msg.append(
+             f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+
+     if len(err_msg) > 0:
+         err_msg.insert(
+             0, 'The model and loaded state dict do not match exactly\n')
+         err_msg = '\n'.join(err_msg)
+         if strict:
+             raise RuntimeError(err_msg)
+         elif logger is not None:
+             logger.warning(err_msg)
+         else:
+             print(err_msg)
+
+
+ def _load_checkpoint(model,
+                      filename,
+                      map_location='cpu',
+                      strict=False,
+                      logger=None):
+     """Load checkpoint from a file or URI.
+
+     Args:
+         model (Module): Module to load checkpoint.
+         filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+             ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+             details.
+         map_location (str): Same as :func:`torch.load`.
+         strict (bool): Whether to allow different params for the model and
+             checkpoint.
+         logger (:mod:`logging.Logger` or None): The logger for error message.
+
+     Returns:
+         dict or OrderedDict: The loaded checkpoint.
+     """
+     checkpoint = torch.load(filename, map_location=map_location)
+     if not isinstance(checkpoint, dict):
+         raise RuntimeError(
+             f'No state_dict found in checkpoint file {filename}')
+     if 'state_dict' in checkpoint:
+         state_dict = checkpoint['state_dict']
+     elif 'model' in checkpoint:
+         state_dict = checkpoint['model']
+     else:
+         state_dict = checkpoint
+     if list(state_dict.keys())[0].startswith('module.'):
+         state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+     if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+         state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+     _load_state_dict(model, state_dict, strict, logger)
+     return checkpoint
+
+
+ class Downsample(nn.Module):
+     """
+     Down-sampling block
+     """
+
+     def __init__(self,
+                  dim,
+                  keep_dim=False,
+                  ):
+         """
+         Args:
+             dim: feature size dimension.
+             norm_layer: normalization layer.
+             keep_dim: bool argument for maintaining the resolution.
+         """
+
+         super().__init__()
+         if keep_dim:
+             dim_out = dim
+         else:
+             dim_out = 2 * dim
+         self.reduction = nn.Sequential(
+             nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
+         )
+
+     def forward(self, x):
+         x = self.reduction(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """
+     Patch embedding block
+     """
+
+     def __init__(self, in_chans=3, in_dim=64, dim=96):
+         """
+         Args:
+             in_chans: number of input channels.
+             dim: feature size dimension.
+         """
+         # in_dim = 1
+         super().__init__()
+         self.proj = nn.Identity()
+         self.conv_down = nn.Sequential(
+             nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
+             nn.BatchNorm2d(in_dim, eps=1e-4),
+             nn.ReLU(),
+             nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
+             nn.BatchNorm2d(dim, eps=1e-4),
+             nn.ReLU()
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = self.conv_down(x)
+         return x
+
+
+ class ConvBlock(nn.Module):
+
+     def __init__(self, dim,
+                  drop_path=0.,
+                  layer_scale=None,
+                  kernel_size=3):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
+         self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
+         self.act1 = nn.GELU(approximate='tanh')
+         self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
+         self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
+         self.layer_scale = layer_scale
+         if layer_scale is not None and type(layer_scale) in [int, float]:
+             self.g = nn.Parameter(layer_scale * torch.ones(dim))
+             self.layer_scale = True
+         else:
+             self.layer_scale = False
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(self, x):
+         input = x
+         x = self.conv1(x)
+         x = self.norm1(x)
+         x = self.act1(x)
+         x = self.conv2(x)
+         x = self.norm2(x)
+         if self.layer_scale:
+             x = x * self.g.view(1, -1, 1, 1)
+         x = input + self.drop_path(x)
+         return x
+
+
+ class MambaVisionMixer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         d_state=16,
+         d_conv=4,
+         expand=2,
+         dt_rank="auto",
+         dt_min=0.001,
+         dt_max=0.1,
+         dt_init="random",
+         dt_scale=1.0,
+         dt_init_floor=1e-4,
+         conv_bias=True,
+         bias=False,
+         use_fast_path=True,
+         layer_idx=None,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.d_model = d_model
+         self.d_state = d_state
+         self.d_conv = d_conv
+         self.expand = expand
+         self.d_inner = int(self.expand * self.d_model)
+         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+         self.use_fast_path = use_fast_path
+         self.layer_idx = layer_idx
+         self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
+         self.x_proj = nn.Linear(
+             self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
+         )
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
+         dt_init_std = self.dt_rank**-0.5 * dt_scale
+         if dt_init == "constant":
+             nn.init.constant_(self.dt_proj.weight, dt_init_std)
+         elif dt_init == "random":
+             nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+         else:
+             raise NotImplementedError
+         dt = torch.exp(
+             torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+             + math.log(dt_min)
+         ).clamp(min=dt_init_floor)
+         inv_dt = dt + torch.log(-torch.expm1(-dt))
+         with torch.no_grad():
+             self.dt_proj.bias.copy_(inv_dt)
+         self.dt_proj.bias._no_reinit = True
+         A = repeat(
+             torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+             "n -> d n",
+             d=self.d_inner//2,
+         ).contiguous()
+         A_log = torch.log(A)
+         self.A_log = nn.Parameter(A_log)
+         self.A_log._no_weight_decay = True
+         self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
+         self.D._no_weight_decay = True
+         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+         self.conv1d_x = nn.Conv1d(
+             in_channels=self.d_inner//2,
+             out_channels=self.d_inner//2,
+             bias=conv_bias//2,
+             kernel_size=d_conv,
+             groups=self.d_inner//2,
+             **factory_kwargs,
+         )
+         self.conv1d_z = nn.Conv1d(
+             in_channels=self.d_inner//2,
+             out_channels=self.d_inner//2,
+             bias=conv_bias//2,
+             kernel_size=d_conv,
+             groups=self.d_inner//2,
+             **factory_kwargs,
+         )
+
+     def forward(self, hidden_states):
+         """
+         hidden_states: (B, L, D)
+         Returns: same shape as hidden_states
+         """
+         _, seqlen, _ = hidden_states.shape
+         xz = self.in_proj(hidden_states)
+         xz = rearrange(xz, "b l d -> b d l")
+         x, z = xz.chunk(2, dim=1)
+         A = -torch.exp(self.A_log.float())
+         x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
+         z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
+         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
+         dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+         dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
+         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+         C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+         y = selective_scan_fn(x,
+                               dt,
+                               A,
+                               B,
+                               C,
+                               self.D.float(),
+                               z=None,
+                               delta_bias=self.dt_proj.bias.float(),
+                               delta_softplus=True,
+                               return_last_state=None)
+
+         y = torch.cat([y, z], dim=1)
+         y = rearrange(y, "b d l -> b l d")
+         out = self.out_proj(y)
+         return out
+
+
+ class Attention(nn.Module):
+
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         qk_norm=False,
+         attn_drop=0.,
+         proj_drop=0.,
+         norm_layer=nn.LayerNorm,
+     ):
+         super().__init__()
+         assert dim % num_heads == 0
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+         self.fused_attn = True
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.fused_attn:
+             x = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p=self.attn_drop.p,
+             )
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self,
+                  dim,
+                  num_heads,
+                  counter,
+                  transformer_blocks,
+                  mlp_ratio=4.,
+                  qkv_bias=False,
+                  qk_scale=False,
+                  drop=0.,
+                  attn_drop=0.,
+                  drop_path=0.,
+                  act_layer=nn.GELU,
+                  norm_layer=nn.LayerNorm,
+                  Mlp_block=Mlp,
+                  layer_scale=None,
+                  ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         if counter in transformer_blocks:
+             self.mixer = Attention(
+                 dim,
+                 num_heads=num_heads,
+                 qkv_bias=qkv_bias,
+                 qk_norm=qk_scale,
+                 attn_drop=attn_drop,
+                 proj_drop=drop,
+                 norm_layer=norm_layer,
+             )
+         else:
+             self.mixer = MambaVisionMixer(d_model=dim,
+                                           d_state=8,
+                                           d_conv=3,
+                                           expand=1
+                                           )
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+         use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
+         self.g_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
+         self.g_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
+
+     def forward(self, x):
+         x = x + self.drop_path(self.g_1 * self.mixer(self.norm1(x)))
+         x = x + self.drop_path(self.g_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class MambaVisionLayer(nn.Module):
+     """
+     MambaVision layer
+     """
+
+     def __init__(self,
+                  dim,
+                  depth,
+                  num_heads,
+                  window_size,
+                  conv=False,
+                  downsample=True,
+                  mlp_ratio=4.,
+                  qkv_bias=True,
+                  qk_scale=None,
+                  drop=0.,
+                  attn_drop=0.,
+                  drop_path=0.,
+                  layer_scale=None,
+                  layer_scale_conv=None,
+                  transformer_blocks=[],
+                  ):
+         """
+         Args:
+             dim: feature size dimension.
+             depth: number of layers in each stage.
+             window_size: window size in each stage.
+             conv: bool argument for conv stage flag.
+             downsample: bool argument for down-sampling.
+             mlp_ratio: MLP ratio.
+             num_heads: number of heads in each stage.
+             qkv_bias: bool argument for query, key, value learnable bias.
+             qk_scale: bool argument to scaling query, key.
+             drop: dropout rate.
+             attn_drop: attention dropout rate.
+             drop_path: drop path rate.
+             norm_layer: normalization layer.
+             layer_scale: layer scaling coefficient.
+             layer_scale_conv: conv layer scaling coefficient.
+             transformer_blocks: list of transformer blocks.
+         """
+
+         super().__init__()
+         self.conv = conv
+         self.transformer_block = False
+         if conv:
+             self.blocks = nn.ModuleList([ConvBlock(dim=dim,
+                                                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                                    layer_scale=layer_scale_conv)
+                                          for i in range(depth)])
+             self.transformer_block = False
+         else:
+             self.transformer_block = True
+             self.blocks = nn.ModuleList([Block(dim=dim,
+                                                counter=i,
+                                                transformer_blocks=transformer_blocks,
+                                                num_heads=num_heads,
+                                                mlp_ratio=mlp_ratio,
+                                                qkv_bias=qkv_bias,
+                                                qk_scale=qk_scale,
+                                                drop=drop,
+                                                attn_drop=attn_drop,
+                                                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                                layer_scale=layer_scale)
+                                          for i in range(depth)])
+             self.transformer_block = True
+
+         self.downsample = None if not downsample else Downsample(dim=dim)
+         self.do_gt = False
+         self.window_size = window_size
+
+     def forward(self, x):
+         _, _, H, W = x.shape
+
+         if self.transformer_block:
+             pad_r = (self.window_size - W % self.window_size) % self.window_size
+             pad_b = (self.window_size - H % self.window_size) % self.window_size
+             if pad_r > 0 or pad_b > 0:
+                 x = torch.nn.functional.pad(x, (0, pad_r, 0, pad_b))
+                 _, _, Hp, Wp = x.shape
+             else:
+                 Hp, Wp = H, W
+             x = window_partition(x, self.window_size)
+
+         for _, blk in enumerate(self.blocks):
+             x = blk(x)
+         if self.transformer_block:
+             x = window_reverse(x, self.window_size, Hp, Wp)
+             if pad_r > 0 or pad_b > 0:
+                 x = x[:, :, :H, :W].contiguous()
+         if self.downsample is None:
+             return x, x
+         return self.downsample(x), x
+
+
+ class MambaVision(nn.Module):
+     """
+     MambaVision
+     """
+
+     def __init__(self,
+                  dim,
+                  in_dim,
+                  depths,
+                  window_size,
+                  mlp_ratio,
+                  num_heads,
+                  drop_path_rate=0.2,
+                  in_chans=3,
+                  num_classes=1000,
+                  qkv_bias=True,
+                  qk_scale=None,
+                  drop_rate=0.,
+                  attn_drop_rate=0.,
+                  layer_scale=None,
+                  layer_scale_conv=None,
+                  **kwargs):
+         """
+         Args:
+             dim: feature size dimension.
+             depths: number of layers in each stage.
+             window_size: window size in each stage.
+             mlp_ratio: MLP ratio.
+             num_heads: number of heads in each stage.
+             drop_path_rate: drop path rate.
+             in_chans: number of input channels.
+             num_classes: number of classes.
+             qkv_bias: bool argument for query, key, value learnable bias.
+             qk_scale: bool argument to scaling query, key.
+             drop_rate: dropout rate.
+             attn_drop_rate: attention dropout rate.
+             norm_layer: normalization layer.
+             layer_scale: layer scaling coefficient.
+             layer_scale_conv: conv layer scaling coefficient.
+         """
+         super().__init__()
+         num_features = int(dim * 2 ** (len(depths) - 1))
+         self.num_classes = num_classes
+         self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+         self.levels = nn.ModuleList()
+         for i in range(len(depths)):
+             conv = True if (i == 0 or i == 1) else False
+             level = MambaVisionLayer(dim=int(dim * 2 ** i),
+                                      depth=depths[i],
+                                      num_heads=num_heads[i],
+                                      window_size=window_size[i],
+                                      mlp_ratio=mlp_ratio,
+                                      qkv_bias=qkv_bias,
+                                      qk_scale=qk_scale,
+                                      conv=conv,
+                                      drop=drop_rate,
+                                      attn_drop=attn_drop_rate,
+                                      drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                                      downsample=(i < 3),
+                                      layer_scale=layer_scale,
+                                      layer_scale_conv=layer_scale_conv,
+                                      transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
+                                      )
+             self.levels.append(level)
+         self.norm = nn.BatchNorm2d(num_features)
+         self.avgpool = nn.AdaptiveAvgPool2d(1)
+         self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, LayerNorm2d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.ones_(m.weight)
+             nn.init.zeros_(m.bias)
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {'rpb'}
+
+     def forward_features(self, x):
+         x = self.patch_embed(x)
+         outs = []
+         for level in self.levels:
+             x, xo = level(x)
+             outs.append(xo)
+         x = self.norm(x)
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         return x, outs
+
+     def forward(self, x):
+         x, outs = self.forward_features(x)
+         x = self.head(x)
+         return x
+
+     def _load_state_dict(self,
+                          pretrained,
+                          strict: bool = False):
+         _load_checkpoint(self,
+                          pretrained,
+                          strict=strict)
+
+
+ class MambaVisionModel(PreTrainedModel):
+     config_class = MambaVisionConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = MambaVision(
+             depths=config.depths,
+             num_heads=config.num_heads,
+             window_size=config.window_size,
+             dim=config.dim,
+             in_dim=config.in_dim,
+             mlp_ratio=config.mlp_ratio,
+             layer_scale=config.layer_scale,
+             layer_scale_conv=config.layer_scale_conv
+         )
+
+     def forward(self, tensor):
+         return self.model.forward_features(tensor)
+
+
+ class MambaVisionModelForImageClassification(PreTrainedModel):
+     config_class = MambaVisionConfig
+
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = MambaVision(
+             depths=config.depths,
+             num_heads=config.num_heads,
+             window_size=config.window_size,
+             dim=config.dim,
+             in_dim=config.in_dim,
+             mlp_ratio=config.mlp_ratio,
+             layer_scale=config.layer_scale,
+             layer_scale_conv=config.layer_scale_conv
+         )
+
+     def forward(self, tensor, labels=None):
+         logits = self.model(tensor)
+         if labels is not None:
+             loss = torch.nn.functional.cross_entropy(logits, labels)
+             return {"loss": loss, "logits": logits}
+         return {"logits": logits}
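
For completeness, `MambaVisionModelForImageClassification` returns a dict with `logits` and, when labels are supplied, a cross-entropy `loss`. A minimal sketch of that path, built from a fresh (randomly initialized) config rather than the uploaded checkpoint, with the same `mamba_ssm`/CUDA requirements as above:

```python
# Sketch only: random weights, just to exercise the forward/loss path.
import torch
from configuration_mambavision import MambaVisionConfig
from modeling_mambavision import MambaVisionModelForImageClassification

config = MambaVisionConfig()                     # defaults mirror config.json
model = MambaVisionModelForImageClassification(config).cuda().eval()

images = torch.randn(2, 3, 224, 224, device="cuda")
labels = torch.tensor([1, 42], device="cuda")

with torch.no_grad():
    out = model(images, labels=labels)

print(out["logits"].shape)   # (2, 1000) -- MambaVision defaults to 1000 classes
print(out["loss"].item())
```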