Sadanand Modak
changes
488f65c
raw
history blame
953 Bytes
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights, vit_b_16, ViT_B_16_Weights
def create_effnetb2_model(num_classes=3):
    """Build a frozen EfficientNet-B2 backbone with a new classification head.

    The model is constructed with ``EfficientNet_B2_Weights.DEFAULT``. All of
    its existing parameters are frozen, and ``classifier[1]`` is replaced by a
    freshly initialised ``nn.Linear`` producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the replacement linear layer.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 model and
        the preprocessing transforms returned by ``weights.transforms()``.
    """
    weights_effnetb2 = EfficientNet_B2_Weights.DEFAULT
    transforms_effnetb2 = weights_effnetb2.transforms()
    model_effnetb2 = efficientnet_b2(weights=weights_effnetb2)

    # Freeze every pretrained parameter so training only updates the new head.
    for param in model_effnetb2.parameters():
        param.requires_grad = False

    # Take the input width from the layer being replaced rather than
    # hard-coding it (previously 1408), so the head always matches the backbone.
    in_features = model_effnetb2.classifier[1].in_features
    # The new layer is created after the freeze loop, so its parameters keep
    # the default requires_grad=True and it is the only trainable part.
    model_effnetb2.classifier[1] = nn.Linear(
        in_features=in_features, out_features=num_classes
    )
    return model_effnetb2, transforms_effnetb2
def create_vitb16_model(num_classes=3):
    """Build a frozen ViT-B/16 backbone with a new classification head.

    The model is constructed with ``ViT_B_16_Weights.DEFAULT``. All of its
    existing parameters are frozen, and ``heads[0]`` is replaced by a freshly
    initialised ``nn.Linear`` producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the replacement linear layer.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT-B/16 model and the
        preprocessing transforms returned by ``weights.transforms()``.
    """
    weights_vitb16 = ViT_B_16_Weights.DEFAULT
    transforms_vitb16 = weights_vitb16.transforms()
    model_vitb16 = vit_b_16(weights=weights_vitb16)

    # Freeze every pretrained parameter so training only updates the new head.
    for param in model_vitb16.parameters():
        param.requires_grad = False

    # Take the input width from the layer being replaced rather than
    # hard-coding it (previously 768), so the head always matches the backbone.
    in_features = model_vitb16.heads[0].in_features
    # The new layer is created after the freeze loop, so its parameters keep
    # the default requires_grad=True and it is the only trainable part.
    model_vitb16.heads[0] = nn.Linear(
        in_features=in_features, out_features=num_classes
    )
    return model_vitb16, transforms_vitb16