Update README.md
README.md
CHANGED
---
license: other
license_name: nvclv1
license_link: LICENSE
datasets:
- ILSVRC/imagenet-21k
pipeline_tag: image-feature-extraction
---

[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083)

## Model Overview

We have developed the first hybrid model for computer vision that leverages the strengths of Mamba and Transformers. Our core contribution is a redesign of the Mamba formulation to enhance its capability for efficient modeling of visual features. In addition, we conducted a comprehensive ablation study on the feasibility of integrating Vision Transformers (ViT) with Mamba. Our results demonstrate that equipping the Mamba architecture with several self-attention blocks at the final layers greatly improves its capacity to capture long-range spatial dependencies. Based on these findings, we introduce a family of MambaVision models with a hierarchical architecture to meet various design criteria.

## Model Performance

MambaVision-L3-256-21K is pretrained on the ImageNet-21K dataset and finetuned on ImageNet-1K. Both pretraining and finetuning are performed at 256 x 256 resolution.

<table>
  <tr>
    <th>Name</th>
    <th>Acc@1(%)</th>
    <th>Acc@5(%)</th>
    <th>#Params(M)</th>
    <th>FLOPs(G)</th>
    <th>Resolution</th>
  </tr>
  <tr>
    <td>MambaVision-L3-256-21K</td>
    <td>87.3</td>
    <td>98.3</td>
    <td>739.6</td>
    <td>122.3</td>
    <td>256x256</td>
  </tr>
</table>

In addition, the MambaVision models demonstrate strong performance, achieving a new SOTA Pareto front in terms of Top-1 accuracy and throughput.

<p align="center">
<img src="https://github.com/NVlabs/MambaVision/assets/26806394/79dcf841-3966-4b77-883d-76cd5e1d4320" width=70% height=70%
class="center">
</p>

## Model Usage

It is highly recommended to install the requirements for MambaVision by running the following:

```Bash
pip install mambavision
```

For each model, we offer two variants, for image classification and for feature extraction, that can be imported with one line of code (see the sketch below).
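
As a minimal sketch, assuming only what the snippets below show (the `transformers` Auto classes and the `nvidia/MambaVision-L3-256-21K` checkpoint loaded with `trust_remote_code=True`), the two variants can be loaded as follows:

```Python
from transformers import AutoModel, AutoModelForImageClassification

# classification variant: includes the ImageNet-1K classification head
classifier = AutoModelForImageClassification.from_pretrained(
    "nvidia/MambaVision-L3-256-21K", trust_remote_code=True
)

# feature-extraction variant: returns per-stage features and pooled features
backbone = AutoModel.from_pretrained(
    "nvidia/MambaVision-L3-256-21K", trust_remote_code=True
)
```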

### Image Classification

In the following example, we demonstrate how MambaVision can be used for image classification.

Given the following image from the [COCO dataset](https://cocodataset.org/#home) validation set as input:

<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64414b62603214724ebd2636/4duSnqLf4lrNiAHczSmAN.jpeg" width=70% height=70%
class="center">
</p>

The following snippet can be used for image classification:

```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L3-256-21K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 256, 256)  # MambaVision supports any input resolution

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

The predicted label is `brown bear, bruin, Ursus arctos`.
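
If class probabilities or top-5 predictions are also of interest, a small extension of the snippet above (reusing its `model` and `logits`; the `torch.softmax`/`topk` post-processing here is our illustration, not part of the original example) could look like:

```Python
import torch

# convert logits to probabilities and list the five most likely classes
probs = torch.softmax(logits, dim=-1)
top5_prob, top5_idx = probs.topk(5, dim=-1)
for p, idx in zip(top5_prob[0].tolist(), top5_idx[0].tolist()):
    print(f"{model.config.id2label[idx]}: {p:.3f}")
```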

### Feature Extraction

MambaVision can also be used as a generic feature extractor.

Specifically, we can extract the outputs of each of the model's four stages as well as the final average-pooled, flattened features.

The following snippet can be used for feature extraction:

```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModel.from_pretrained("nvidia/MambaVision-L3-256-21K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 256, 256)  # MambaVision supports any input resolution

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())  # torch.Size([1, 1568])
print("Number of stages in extracted features:", len(features))  # 4 stages
print("Size of extracted features in stage 1:", features[0].size())  # torch.Size([1, 196, 128, 128])
print("Size of extracted features in stage 4:", features[3].size())  # torch.Size([1, 1568, 16, 16])
```
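
As an illustration of how these features might be consumed downstream, the following minimal sketch attaches a linear probe to the flattened average-pooled output; the 1568-dimensional size comes from the printout above, while the probe layer and class count are hypothetical and not part of MambaVision itself:

```Python
import torch
import torch.nn as nn

num_classes = 10  # assumption: number of classes in your downstream task
probe = nn.Linear(1568, num_classes).cuda()  # 1568 = pooled feature dimension printed above

# reuse `model` and `inputs` from the feature-extraction snippet above
with torch.no_grad():
    out_avg_pool, features = model(inputs)

probe_logits = probe(out_avg_pool)
print(probe_logits.size())  # torch.Size([1, 10])
```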

### License

[NVIDIA Source Code License-NC](https://huggingface.co/nvidia/MambaVision-L3-256-21K/blob/main/LICENSE)