Sadanand Modak commited on
Commit
488f65c
·
1 Parent(s): 64cc1c5
Files changed (1) hide show
  1. model.py +11 -1
model.py CHANGED
@@ -1,5 +1,5 @@
1
  from torch import nn
2
- from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
3
 
4
 
5
  def create_effnetb2_model(num_classes=3):
@@ -10,3 +10,13 @@ def create_effnetb2_model(num_classes=3):
10
  param.requires_grad = False
11
  model_effnetb2.classifier[1] = nn.Linear(in_features=1408, out_features=num_classes)
12
  return model_effnetb2, transforms_effnetb2
 
 
 
 
 
 
 
 
 
 
 
1
  from torch import nn
2
+ from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights, vit_b_16, ViT_B_16_Weights
3
 
4
 
5
  def create_effnetb2_model(num_classes=3):
 
10
  param.requires_grad = False
11
  model_effnetb2.classifier[1] = nn.Linear(in_features=1408, out_features=num_classes)
12
  return model_effnetb2, transforms_effnetb2
13
+
14
+
15
+ def create_vitb16_model(num_classes=3):
16
+ weights_vitb16 = ViT_B_16_Weights.DEFAULT
17
+ transforms_vitb16 = weights_vitb16.transforms()
18
+ model_vitb16 = vit_b_16(weights=weights_vitb16)
19
+ for param in model_vitb16.parameters():
20
+ param.requires_grad = False
21
+ model_vitb16.heads[0] = nn.Linear(in_features=768, out_features=num_classes)
22
+ return model_vitb16, transforms_vitb16