# Model factories: EfficientNet-B2 and ViT-B/16 feature extractors with
# pretrained (frozen) backbones and a replaceable classification head.
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights, vit_b_16, ViT_B_16_Weights
def create_effnetb2_model(num_classes=3):
    """Build a frozen, pretrained EfficientNet-B2 feature extractor.

    Loads the DEFAULT pretrained weights, freezes every backbone
    parameter, and swaps the final classifier layer for a fresh
    trainable ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head (default 3).

    Returns:
        tuple: ``(model, transforms)`` — the modified model and the
        preprocessing transforms matching the pretrained weights.
    """
    weights_effnetb2 = EfficientNet_B2_Weights.DEFAULT
    # Use the transforms the weights were trained with so inference
    # preprocessing matches training preprocessing.
    transforms_effnetb2 = weights_effnetb2.transforms()
    model_effnetb2 = efficientnet_b2(weights=weights_effnetb2)

    # Freeze the backbone; only the new head below will be trainable.
    for param in model_effnetb2.parameters():
        param.requires_grad = False

    # Derive in_features from the existing head instead of hard-coding 1408,
    # so this stays correct if torchvision changes the architecture details.
    in_features = model_effnetb2.classifier[1].in_features
    model_effnetb2.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)

    return model_effnetb2, transforms_effnetb2
def create_vitb16_model(num_classes=3):
    """Build a frozen, pretrained ViT-B/16 feature extractor.

    Loads the DEFAULT pretrained weights, freezes every backbone
    parameter, and swaps the classification head for a fresh trainable
    ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head (default 3).

    Returns:
        tuple: ``(model, transforms)`` — the modified model and the
        preprocessing transforms matching the pretrained weights.
    """
    weights_vitb16 = ViT_B_16_Weights.DEFAULT
    # Use the transforms the weights were trained with so inference
    # preprocessing matches training preprocessing.
    transforms_vitb16 = weights_vitb16.transforms()
    model_vitb16 = vit_b_16(weights=weights_vitb16)

    # Freeze the backbone; only the new head below will be trainable.
    for param in model_vitb16.parameters():
        param.requires_grad = False

    # Derive in_features from the existing head instead of hard-coding 768,
    # so this stays correct if torchvision changes the architecture details.
    in_features = model_vitb16.heads[0].in_features
    model_vitb16.heads[0] = nn.Linear(in_features=in_features, out_features=num_classes)

    return model_vitb16, transforms_vitb16