---
tags:
- image-classification
- birder
- pytorch
library_name: birder
license: apache-2.0
---

# Model Card for mvit_v2_t_il-all

An MViTv2 image classification model. This model was trained on the `il-all` dataset, encompassing all relevant bird species found in Israel, including rarities.

The species list is derived from data available at <https://www.israbirding.com/checklist/>.

## Model Details

- **Model Type:** Image classification and detection backbone
- **Model Stats:**
    - Params (M): 23.9
    - Input image size: 384 x 384
- **Dataset:** il-all (550 classes)

- **Papers:**
    - MViTv2: Improved Multiscale Vision Transformers for Classification and Detection: <https://arxiv.org/abs/2112.01526>

## Model Usage

### Image Classification

```python
import birder
from birder.inference.classification import infer_image

(net, model_info) = birder.load_pretrained_model("mvit_v2_t_il-all", inference=True)

# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)

# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)

image = "path/to/image.jpeg"  # or a PIL image, must be loaded in RGB format
(out, _) = infer_image(net, image, transform)
# out is a NumPy array with shape (1, 550), representing class probabilities
```
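
The probability array can be post-processed with plain NumPy, for example to pick the top-k predictions. A minimal sketch, continuing from the snippet above (mapping class indices back to species names is omitted here, as it depends on the class list you use):

```python
import numpy as np

# `out` has shape (1, 550); take the single row of class probabilities
probs = out[0]

# Indices of the 5 highest probabilities, from most to least likely
top5 = np.argsort(probs)[::-1][:5]
for idx in top5:
    print(f"class index {idx}: probability {probs[idx]:.4f}")
```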

### Image Embeddings

```python
import birder
from birder.inference.classification import infer_image

(net, model_info) = birder.load_pretrained_model("mvit_v2_t_il-all", inference=True)

# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)

# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)

image = "path/to/image.jpeg"  # or a PIL image
(out, embedding) = infer_image(net, image, transform, return_embedding=True)
# embedding is a NumPy array with shape (1, 768)
```
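
The embeddings can be compared directly, for example with cosine similarity, to find visually similar images. A minimal sketch using NumPy, continuing from the snippet above (the second image path is a placeholder):

```python
import numpy as np

# Embed a second image with the same model and transform
(_, embedding_other) = infer_image(net, "path/to/other_image.jpeg", transform, return_embedding=True)

# Cosine similarity between the two (1, 768) embeddings
a = embedding[0]
b = embedding_other[0]
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {similarity:.4f}")
```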

### Detection Feature Map

```python
from PIL import Image
import birder

(net, model_info) = birder.load_pretrained_model("mvit_v2_t_il-all", inference=True)

# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)

# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)

image = Image.open("path/to/image.jpeg")
features = net.detection_features(transform(image).unsqueeze(0))
# features is a dict (stage name -> torch.Tensor)
print([(k, v.size()) for k, v in features.items()])
# Output example:
# [('stage1', torch.Size([1, 96, 96, 96])),
#  ('stage2', torch.Size([1, 192, 48, 48])),
#  ('stage3', torch.Size([1, 384, 24, 24])),
#  ('stage4', torch.Size([1, 768, 12, 12]))]
```
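
The per-stage feature maps can be fed to a downstream detection head (e.g. an FPN). As a rough illustration of their layout, each (N, C, H, W) map can be reduced to a per-stage vector with global average pooling; a minimal sketch, continuing from the snippet above:

```python
# Global average pooling over the spatial dimensions of each stage
pooled = {name: feat.mean(dim=(2, 3)) for name, feat in features.items()}
print([(k, v.size()) for k, v in pooled.items()])
# e.g. [('stage1', torch.Size([1, 96])), ..., ('stage4', torch.Size([1, 768]))]
```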

## Citation

```bibtex
@misc{li2022mvitv2improvedmultiscalevision,
      title={MViTv2: Improved Multiscale Vision Transformers for Classification and Detection},
      author={Yanghao Li and Chao-Yuan Wu and Haoqi Fan and Karttikeya Mangalam and Bo Xiong and Jitendra Malik and Christoph Feichtenhofer},
      year={2022},
      eprint={2112.01526},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2112.01526},
}
```