tags:
- image-classification
- pytorch
- resnet
- cifar10
license: apache-2.0
ResNet-18 for CIFAR-10 Image Classification
This is a ResNet-18 model fine-tuned on the CIFAR-10 dataset for image classification.
Model Details:
- Architecture: ResNet-18, pre-trained on ImageNet.
- Dataset: Fine-tuned on the CIFAR-10 dataset.
- Task: Image Classification.
- Classes: Classifies images into 10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
- Framework: PyTorch.
How to Use:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Load the model architecture (you need to define your CIFAR10_ResNet class, or adapt a standard ResNet)
model = models.resnet18(num_classes=10) # Or however you adapted ResNet for 10 classes
model.load_state_dict(torch.load('resnet18_cifar10.pth', map_location='cpu')) # map_location lets a GPU-saved checkpoint load on CPU-only machines
model.eval()
# Image preprocessing (use the same transforms as during training)
preprocess = transforms.Compose([
    transforms.Resize(256), # Or whatever input size your model expects
    transforms.CenterCrop(224), # Or input crop size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet normalization
])
# Load and preprocess an example image (replace 'your_image.jpg')
image_path = 'your_image.jpg'
img = Image.open(image_path).convert('RGB') # Force 3 channels so grayscale or RGBA files match the model input
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
# Inference
with torch.no_grad():
    output = model(batch_t)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
_, predicted_class = torch.max(probabilities, 0)
# ... (Code to map predicted_class index to class name if you have class names) ...
print(f"Predicted class index: {predicted_class.item()}")
print(f"Probabilities: {probabilities.tolist()}")
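The predicted index can be turned into a readable label. The following is a minimal sketch, assuming the model's output indices follow the class order listed under Model Details (which is also the order torchvision's CIFAR10 dataset uses). `CIFAR10_CLASSES` and `top_k_labels` are hypothetical names, and the probabilities are made-up values standing in for `probabilities.tolist()`.

```python
# Assumed index-to-name order; verify it against the label mapping used in training.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def top_k_labels(probs, k=3):
    """Return the k most likely (class_name, probability) pairs, best first."""
    if len(probs) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} probabilities, got {len(probs)}")
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(CIFAR10_CLASSES[idx], p) for idx, p in ranked[:k]]


# Made-up values for illustration; in practice pass probabilities.tolist().
example_probs = [0.01, 0.02, 0.05, 0.60, 0.03, 0.20, 0.04, 0.02, 0.02, 0.01]
for name, p in top_k_labels(example_probs):
    print(f"{name}: {p:.2f}")
# cat: 0.60
# dog: 0.20
# bird: 0.05
```

Reporting the top few classes rather than only the argmax makes it easier to spot near-ties between visually similar categories such as cat and dog.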
Model tree for bhumong/resnet18-cifar10: base model microsoft/resnet-18