"""Gradio demo: classify an uploaded image into one of six labels with a small CNN.

Weights are read from ``model_deri.pth``; the architecture below is redefined to
match the structure the checkpoint was trained with.
"""

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

MODEL_PATH = "model_deri.pth"


class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer classifier head.

    Must stay structurally identical to the trained model so its saved
    state dict can be loaded into it.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            # 3x3 convs with padding=1 keep the spatial size and each of the three
            # 2x2 max-pools halves it, so a 128x128 input becomes 128 x 16 x 16.
            nn.Linear(128 * 16 * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 6),  # one logit per entry in class_mapping
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 6).

        ``x`` must be (batch, 3, 128, 128); any other spatial size breaks the
        first Linear layer of ``fc_layers``.
        """
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x


# Build the network and load the trained weights onto the CPU.
model = SimpleCNN()
try:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load trusted checkpoints (newer torch accepts weights_only=True).
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    # strict=False tolerates key mismatches instead of raising, so the result must
    # be checked -- otherwise a mismatched checkpoint silently leaves random weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
except Exception as e:
    print(f"❌ Error loading model: {e}")
else:
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "⚠️ Model loaded with mismatched keys; predictions may be unreliable.\n"
            f"   missing: {incompatible.missing_keys}\n"
            f"   unexpected: {incompatible.unexpected_keys}"
        )
    else:
        print("✅ Model berhasil dimuat!")
# Inference mode regardless of whether loading succeeded.
model.eval()

# Index -> label for the 6 output logits.
class_mapping = {0: 'Bu dian', 1: 'Deri', 2: 'Putra', 3: 'Unknown', 4: 'Uqi', 5: 'Uwa'}

# Resize to the 128x128 input the network expects and scale pixels from [0, 1]
# to [-1, 1]. Presumably mirrors the training-time preprocessing -- confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def predict(image):
    """Classify one image and return a human-readable prediction string.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted: <label> (Confidence: <pct>%)"``, or a prompt to upload an
        image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    # Uploaded files may be RGBA, grayscale or palette images; the first conv
    # layer and the 3-channel Normalize both require exactly 3 channels.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    probabilities = torch.nn.functional.softmax(output, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_class].item() * 100  # as a percentage
    return f"Predicted: {class_mapping[predicted_class]} (Confidence: {confidence:.2f}%)"


# Gradio UI: one image in, one text result out.
iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

if __name__ == "__main__":
    try:
        iface.launch()
    except Exception as e:
        print(f"❌ Gradio error: {e}")