roylvzn commited on
Commit
30a63a9
·
verified ·
1 Parent(s): ea99520

init: uploading model and architecture

Browse files
Files changed (4) hide show
  1. README.md +53 -3
  2. cifar10_classes.json +1 -0
  3. pytorch_model.bin +3 -0
  4. vit_model.py +11 -0
README.md CHANGED
@@ -1,3 +1,53 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: timm
3
+ license: apache-2.0
4
+ datasets:
5
+ - cifar10
6
+ tags:
7
+ - vision
8
+ - image-classification
9
+ - cifar10
10
+ - vit
11
+ model-index:
12
+ - name: vit-cifar10
13
+ results:
14
+ - task: {type: image-classification}
15
+ dataset: {name: CIFAR-10, type: cifar10}
16
+ metrics:
17
+ - type: accuracy
18
+ value: 0.95 # replace with your test accuracy
19
+ ---
20
+
21
+ # ViT Base (patch16, 224) fine-tuned on CIFAR-10
22
+
23
+ Trained on CIFAR-10 (10 classes). Weights saved as a plain PyTorch `state_dict` (`pytorch_model.bin`).
24
+ Architecture is defined in `vit_model.py` (uses `timm`).
25
+
26
+ ## Usage
27
+
28
+ ```python
29
+ import torch, json
30
+ from huggingface_hub import hf_hub_download
31
+ import importlib.util
32
+
33
+ repo_id = "roylvzn/vit-cifar10"
34
+
35
+ # fetch files
36
+ weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
37
+ model_py = hf_hub_download(repo_id, "vit_model.py")
38
+ classes_path = hf_hub_download(repo_id, "classes.json")
39
+
40
+ # import vit_model.py dynamically
41
+ spec = importlib.util.spec_from_file_location("vit_model", model_py)
42
+ vm = importlib.util.module_from_spec(spec); spec.loader.exec_module(vm)
43
+
44
+ # build model and load weights
45
+ model = vm.ViTModel(num_classes=10, pretrained=False)
46
+ state = torch.load(weights_path, map_location="cpu")
47
+ model.load_state_dict(state)
48
+ model.eval()
49
+
50
+ with open(classes_path) as f:
51
+ classes = json.load(f)
52
+
53
+ # inference expects 224x224 ImageNet-normalized tensors
cifar10_classes.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cad4c8468fd4eab092a7b0c7f6c7bdd9b0bf4c337d255967e81594736b3beff2
3
+ size 343286237
vit_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+ class ViTModel(nn.Module):
5
+ def __init__(self, num_classes):
6
+ super(ViTModel, self).__init__()
7
+ self.model = timm.create_model('vit_base_patch16_224', pretrained=False)
8
+ self.model.head = nn.Linear(self.model.head.in_features,num_classes)
9
+
10
+ def forward(self, x):
11
+ return self.model(x)