LukeOLuck committed on
Commit
1a96a3d
·
1 Parent(s): a40fe21

init commit

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. app.py +69 -0
  3. class_names.txt +91 -0
  4. examples/automobile.png +0 -0
  5. examples/cat.png +0 -0
  6. examples/frog.png +0 -0
  7. model.py +60 -0
  8. requirements.txt +4 -0
  9. vit.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vit.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from model import create_vit
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
# Read the class labels: one label per line, surrounding whitespace trimmed
with open("class_names.txt", "r") as names_file:
    class_names = [line.strip() for line in names_file]

### Model and transforms preparation ###
# create_vit returns (model, train, val, test transforms); only the model and
# the last (test-time) transforms are kept here.
# NOTE(review): output_shape=101 is passed independently of len(class_names);
# confirm it matches the classifier head stored in vit.pth.
model, _, _, transforms = create_vit(output_shape=101, classes=class_names)

# NOTE(review): the model is compiled *before* the weights are loaded;
# presumably the checkpoint was saved from a compiled model - confirm.
model = torch.compile(model)

# Restore the saved weights, mapping every tensor onto the CPU
cpu_device = torch.device("cpu")
saved_state = torch.load(f="vit.pth", map_location=cpu_device)
model.load_state_dict(saved_state)
24
+
25
+ ### Predict function ###
26
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The image to classify (the Gradio input is declared with
            ``type="pil"``). ``None`` means no image was supplied.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)`` where
        ``pred_labels_and_probs`` maps each name in ``class_names`` to its
        softmax probability and ``pred_time`` is the elapsed time in seconds,
        rounded to 4 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Surface a readable message in the UI instead of a transform error
        raise gr.Error("Please provide an image to classify.")

    # Start a timer
    start_time = timer()

    # Transform the image, add a batch dimension on the 0th index and place
    # the batch on whatever device the model's parameters live on
    model_device = next(model.parameters()).device
    batch = transforms(img).unsqueeze(0).to(model_device)

    # Put model into eval mode, make prediction
    model.eval()
    with torch.inference_mode():
        outputs = model(batch)
        # Hugging Face classification models return an output object that
        # carries the logits in `.logits`; a plain tensor is used as-is
        logits = getattr(outputs, "logits", outputs)
        # Turn the prediction logits into probabilities
        pred_probs = torch.softmax(logits, dim=1)

    # Convert the single batch row to Python floats once, then pair each
    # class name with the probability at the same index
    probs = pred_probs[0].tolist()
    pred_labels_and_probs = {name: probs[i] for i, name in enumerate(class_names)}

    # Calculate pred time
    pred_time = round(timer() - start_time, 4)

    # Return pred dict and pred time
    return pred_labels_and_probs, pred_time
48
+
49
### Gradio app ###
# Text shown on the demo page: title, description and article
title = "A ViT cifar10 Classifier"
description = "An [ViT feature extractor](https://huggingface.co/google/vit-base-patch16-224) computer vision model to classify images on the [10 classes of the cifar10 dataset](https://huggingface.co/datasets/cifar10). [Source Code Found Here](https://colab.research.google.com/drive/1j4NbiMpCqmXN1xw9e2_r77gMdr3WpMnO?usp=drive_link)"
article = "Built with [Gradio](https://github.com/gradio-app/gradio) and [PyTorch](https://pytorch.org/). [Source Code Found Here](https://colab.research.google.com/drive/1j4NbiMpCqmXN1xw9e2_r77gMdr3WpMnO?usp=drive_link)"

# One single-element list per file found in the examples directory
example_list = [[f"examples/{file_name}"] for file_name in os.listdir("examples")]

# Input/output widgets: a PIL image in, the top-5 labels and the timing out
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5, label="Predictions")
time_output = gr.Number(label="Prediction time (s)")

# Wire predict() to the widgets
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[label_output, time_output],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start serving the demo
demo.launch()
class_names.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ automobile
3
+ bird
4
+ cat
5
+ deer
6
+ dog
7
+ frog
8
+ horse
9
+ ship
10
+ truckairplane
11
+ automobile
12
+ bird
13
+ cat
14
+ deer
15
+ dog
16
+ frog
17
+ horse
18
+ ship
19
+ truckairplane
20
+ automobile
21
+ bird
22
+ cat
23
+ deer
24
+ dog
25
+ frog
26
+ horse
27
+ ship
28
+ truckairplane
29
+ automobile
30
+ bird
31
+ cat
32
+ deer
33
+ dog
34
+ frog
35
+ horse
36
+ ship
37
+ truckairplane
38
+ automobile
39
+ bird
40
+ cat
41
+ deer
42
+ dog
43
+ frog
44
+ horse
45
+ ship
46
+ truckairplane
47
+ automobile
48
+ bird
49
+ cat
50
+ deer
51
+ dog
52
+ frog
53
+ horse
54
+ ship
55
+ truckairplane
56
+ automobile
57
+ bird
58
+ cat
59
+ deer
60
+ dog
61
+ frog
62
+ horse
63
+ ship
64
+ truckairplane
65
+ automobile
66
+ bird
67
+ cat
68
+ deer
69
+ dog
70
+ frog
71
+ horse
72
+ ship
73
+ truckairplane
74
+ automobile
75
+ bird
76
+ cat
77
+ deer
78
+ dog
79
+ frog
80
+ horse
81
+ ship
82
+ truckairplane
83
+ automobile
84
+ bird
85
+ cat
86
+ deer
87
+ dog
88
+ frog
89
+ horse
90
+ ship
91
+ truck
examples/automobile.png ADDED
examples/cat.png ADDED
examples/frog.png ADDED
model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision import transforms
5
+ from transformers import ViTForImageClassification
6
+ from transformers import ViTImageProcessor
7
+ from typing import List
8
+
9
device = "cuda" if torch.cuda.is_available() else "cpu"


def create_vit(output_shape: int, classes: List, device: torch.device = device):
    """Creates a HuggingFace ViT model google/vit-base-patch16-224

    The pretrained model is loaded, all of its parameters are frozen, and its
    classifier is replaced by a new linear layer with ``output_shape``
    outputs (the new layer's parameters remain trainable).

    Args:
        output_shape: Number of outputs of the new classifier layer. Note
            that the label mappings are built from ``classes``, so callers
            normally want ``output_shape == len(classes)``.
        classes: A list of class names; the position in the list is the
            class id.
        device: The device the returned model is moved to.

    Returns:
        A tuple of the model, train_transforms, val_transforms, test_transforms
    """
    checkpoint = "google/vit-base-patch16-224"

    # Index <-> label mappings stored in the model config
    id2label = dict(enumerate(classes))
    label2id = {label: idx for idx, label in id2label.items()}

    model = ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(classes),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # Freeze every pretrained parameter
    for param in model.parameters():
        param.requires_grad = False

    # Fresh classification head, created after the freeze so it stays
    # trainable; its input width is read from the config rather than
    # hard-coded. Can add dropout here if needed.
    model.classifier = nn.Linear(
        in_features=model.config.hidden_size,
        out_features=output_shape,
    )

    # Preprocessing statistics come from the checkpoint's image processor, see
    # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    image_mean = processor.image_mean
    image_std = processor.image_std
    size = processor.size["height"]

    normalize = transforms.Normalize(mean=image_mean, std=image_std)

    # Training pipeline: same as evaluation plus a random horizontal flip
    train_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation and test share one pipeline object
    test_transforms = val_transforms

    return model.to(device), train_transforms, val_transforms, test_transforms
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
+ gradio==3.50.2
4
+ transformers==4.35.0
vit.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f92179e0f94cb25af399d2e2e324a2390fd9e9c842728f866451b6b3d7db625
3
+ size 343297010