File size: 953 Bytes
adf757a
488f65c
adf757a
 
 
 
 
 
 
 
 
 
488f65c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights, vit_b_16, ViT_B_16_Weights


def create_effnetb2_model(num_classes: int = 3):
    """Build a pretrained EfficientNet-B2 with a fresh, trainable classification head.

    The model is loaded with ``EfficientNet_B2_Weights.DEFAULT``. Every parameter
    of that pretrained network is frozen, and the final ``Linear`` layer of
    ``classifier`` is then replaced. Only the new layer is trainable.

    Args:
        num_classes: Number of output units of the replacement ``Linear`` layer.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 and the
        preprocessing transforms provided by ``EfficientNet_B2_Weights.DEFAULT``.
    """
    weights_effnetb2 = EfficientNet_B2_Weights.DEFAULT
    transforms_effnetb2 = weights_effnetb2.transforms()
    model_effnetb2 = efficientnet_b2(weights=weights_effnetb2)

    # Freeze the whole pretrained network. The layer created below is built
    # after this loop, so it keeps nn.Linear's default requires_grad=True.
    for param in model_effnetb2.parameters():
        param.requires_grad = False

    # Take the input width from the layer being replaced rather than
    # hard-coding it (1408 for EfficientNet-B2), so the head always matches.
    in_features = model_effnetb2.classifier[1].in_features
    model_effnetb2.classifier[1] = nn.Linear(
        in_features=in_features, out_features=num_classes
    )
    return model_effnetb2, transforms_effnetb2


def create_vitb16_model(num_classes: int = 3):
    """Build a pretrained ViT-B/16 with a fresh, trainable classification head.

    The model is loaded with ``ViT_B_16_Weights.DEFAULT``. Every parameter of
    that pretrained network is frozen, and the first layer of ``heads`` is then
    replaced. Only the new layer is trainable.

    Args:
        num_classes: Number of output units of the replacement ``Linear`` layer.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT-B/16 and the
        preprocessing transforms provided by ``ViT_B_16_Weights.DEFAULT``.
    """
    weights_vitb16 = ViT_B_16_Weights.DEFAULT
    transforms_vitb16 = weights_vitb16.transforms()
    model_vitb16 = vit_b_16(weights=weights_vitb16)

    # Freeze the whole pretrained network. The layer created below is built
    # after this loop, so it keeps nn.Linear's default requires_grad=True.
    for param in model_vitb16.parameters():
        param.requires_grad = False

    # Take the input width from the layer being replaced rather than
    # hard-coding it (768 for ViT-B/16), so the head always matches.
    in_features = model_vitb16.heads[0].in_features
    model_vitb16.heads[0] = nn.Linear(
        in_features=in_features, out_features=num_classes
    )
    return model_vitb16, transforms_vitb16