ruminasval's picture
Upload 2 files
6886af2 verified
raw
history blame contribute delete
950 Bytes
import torch
import timm
from torchvision import transforms
from PIL import Image
# Load model
def load_model(model_path, num_classes=5):
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=num_classes)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
return model
# Preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
# Kelas
class_names = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
# Prediksi bentuk wajah dari gambar
def predict_face_shape(model, image):
image = image.convert("RGB")
img_tensor = transform(image).unsqueeze(0) # Tambah batch dimension
with torch.no_grad():
output = model(img_tensor)
pred = torch.argmax(output, dim=1).item()
return class_names[pred]