ResNet18 Fine-Tuned on CIFAR-10
This model is ResNet18 (pretrained on ImageNet) fine-tuned on the CIFAR-10 dataset. It achieves the following result on the validation/test set:
- Validation accuracy: 88.60%
Model description
- Architecture: ResNet18 with the final fully-connected layer replaced by a 10-class output layer for CIFAR-10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
- Pretrained Weights: ImageNet1K
- Fine-Tuning: The model was fine-tuned on CIFAR-10 images resized (upsampled from their native 32×32) to 128×128 pixels.
- Data Augmentation: Random horizontal flip and random rotation.
- Normalization: Per-channel normalization with mean=0.5 and std=0.5, which maps pixel values from [0, 1] to [-1, 1].
Intended uses & limitations
- Intended use: Educational/demo purposes or as a starting point for further fine-tuning on similar image classification tasks.
- Not intended for: Production-critical tasks without further evaluation. CIFAR-10 is relatively small-scale, so the model may not generalize to images outside CIFAR-10 without additional fine-tuning.
Training procedure
Hyperparameters (approximate):
- optimizer: Adam
- learning_rate: 1e-3
- batch_size: 32
- num_epochs: 15
Hardware:
- This model was trained on a single GPU (`torch.device("cuda")`) if available, otherwise on CPU.
Training logs (per epoch; loss and accuracy are measured on the training set):
Epoch | Training Loss | Training Accuracy | Validation Accuracy |
---|---|---|---|
1 | 0.7013 | 76.52% | - |
2 | 0.4248 | 85.64% | - |
3 | 0.3185 | 89.07% | - |
4 | 0.2341 | 92.06% | - |
5 | 0.1762 | 93.86% | - |
6 | 0.1302 | 95.55% | - |
7 | 0.1085 | 96.31% | - |
8 | 0.0925 | 96.82% | - |
9 | 0.0765 | 97.37% | - |
10 | 0.0683 | 97.68% | - |
11 | 0.0655 | 97.83% | - |
12 | 0.0548 | 98.18% | - |
13 | 0.0513 | 98.27% | - |
14 | 0.0461 | 98.49% | - |
15 | 0.0470 | 98.41% | 88.60% |
Note: Validation accuracy was computed at the end of training (final epoch).
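Because validation accuracy was measured only once, it can help to recompute it for a downloaded checkpoint. The sketch below computes top-1 accuracy over a data loader; the function name `evaluate` is hypothetical, and it assumes the loader applies the same preprocessing as in Usage (resize to 128×128, normalize with mean=0.5 and std=0.5, no augmentation).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients needed for evaluation
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return top-1 accuracy (%) of model over every sample in loader."""
    model.eval()  # BatchNorm uses running statistics; dropout (if any) is disabled
    model.to(device)
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # index of the highest logit per sample
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
    return 100.0 * correct / seen
```

For CIFAR-10, the loader could wrap `torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=...)` with the evaluation preprocessing. Whether this reproduces the reported 88.60% depends on which split the author used.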
Usage
Below is a sample usage snippet in Python. Replace username/model_repo_name with the model's repo id on the Hugging Face Hub (for this card, teguhteja/ttm_cnn_model).
```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Download the fine-tuned weights from the Hugging Face Hub
ckpt_path = hf_hub_download(repo_id="username/model_repo_name", filename="cnn_model.pth")

# Rebuild the same architecture: ResNet18 with a 10-class output layer.
# weights=None skips the ImageNet download; the checkpoint supplies all parameters.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 CIFAR-10 classes
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval()  # inference mode: BatchNorm uses running statistics

# Same preprocessing as during fine-tuning, without augmentation
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Example inference on a single image
image = Image.open("your_image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 128, 128)
with torch.no_grad():
    logits = model(input_tensor)
    predicted_class = logits.argmax(dim=1).item()
print("Predicted class ID:", predicted_class)
```
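The snippet prints only a numeric class ID. The sketch below turns logits into readable labels with probabilities. It assumes output index i corresponds to the i-th class in the order listed under Model description (airplane, automobile, ..., truck), which is also torchvision's CIFAR10 label order; the helper name `top_k_labels` is hypothetical.

```python
import torch

# Assumed label order: torchvision's CIFAR10 ordering, matching this card's class list.
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def top_k_labels(logits: torch.Tensor, k: int = 3):
    """Convert (1, 10) logits into [(class_name, probability), ...], highest first."""
    probs = torch.softmax(logits, dim=1).squeeze(0)  # probabilities over the 10 classes
    values, indices = probs.topk(k)                  # k largest, sorted descending
    return [(CIFAR10_CLASSES[i], p) for i, p in zip(indices.tolist(), values.tolist())]
```

After running the usage snippet, `print(top_k_labels(logits))` lists the three most likely classes. Showing the top few predictions with their probabilities makes low-confidence outputs easier to spot than a single argmax.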
Model tree for teguhteja/ttm_cnn_model
- Base model: microsoft/resnet-18