Spaces:
Runtime error
Runtime error
Sadanand Modak
commited on
Commit
·
488f65c
1
Parent(s):
64cc1c5
changes
Browse files
model.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from torch import nn
|
2 |
-
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
|
3 |
|
4 |
|
5 |
def create_effnetb2_model(num_classes=3):
|
@@ -10,3 +10,13 @@ def create_effnetb2_model(num_classes=3):
|
|
10 |
param.requires_grad = False
|
11 |
model_effnetb2.classifier[1] = nn.Linear(in_features=1408, out_features=num_classes)
|
12 |
return model_effnetb2, transforms_effnetb2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from torch import nn
|
2 |
+
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights, vit_b_16, ViT_B_16_Weights
|
3 |
|
4 |
|
5 |
def create_effnetb2_model(num_classes=3):
|
|
|
10 |
param.requires_grad = False
|
11 |
model_effnetb2.classifier[1] = nn.Linear(in_features=1408, out_features=num_classes)
|
12 |
return model_effnetb2, transforms_effnetb2
|
13 |
+
|
14 |
+
|
15 |
+
def create_vitb16_model(num_classes=3):
|
16 |
+
weights_vitb16 = ViT_B_16_Weights.DEFAULT
|
17 |
+
transforms_vitb16 = weights_vitb16.transforms()
|
18 |
+
model_vitb16 = vit_b_16(weights=weights_vitb16)
|
19 |
+
for param in model_vitb16.parameters():
|
20 |
+
param.requires_grad = False
|
21 |
+
model_vitb16.heads[0] = nn.Linear(in_features=768, out_features=num_classes)
|
22 |
+
return model_vitb16, transforms_vitb16
|