import torch
import torchvision
from torch import nn
from torchvision import transforms


def create_model(num_of_classes: int = 3):
    """Build a MobileNetV3-Large classifier set up for fine-tuning.

    The backbone is loaded with torchvision's default pretrained weights.
    All parameters are frozen first; the classifier and the last three
    feature blocks are then made trainable again, and the final classifier
    layer is swapped for a small MLP head that outputs ``num_of_classes``
    logits.

    Args:
        num_of_classes: Number of output units of the new head.

    Returns:
        A ``(model, transform)`` tuple: the modified network and the
        preprocessing transform obtained from ``weights.transforms()``.
    """
    weights = torchvision.models.MobileNet_V3_Large_Weights.DEFAULT
    transform = weights.transforms()
    model = torchvision.models.mobilenet_v3_large(weights=weights)

    # Freeze everything, then selectively re-enable gradients below.
    for parameter in model.parameters():
        parameter.requires_grad = False

    # Unfreeze the last four classifier modules and the last three
    # feature blocks so they are updated during training.
    for parameter in model.classifier[-4:].parameters():
        parameter.requires_grad = True
    for parameter in model.features[-3:].parameters():
        parameter.requires_grad = True

    # Take the head's input width from the layer being replaced instead of
    # hard-coding it, so the head always matches the backbone's output.
    head_in_features = model.classifier[3].in_features

    # Freshly constructed modules have trainable parameters by default,
    # so the new head needs no explicit requires_grad handling.
    model.classifier[3] = nn.Sequential(
        nn.Linear(head_in_features, 1000),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(1000, num_of_classes),
    )

    return model, transform