ResNet18 Fine-Tuned on CIFAR-10

This model is ResNet18, pretrained on ImageNet and then fine-tuned on the CIFAR-10 dataset. It achieves the following result on the held-out validation/test set:

  • Validation Accuracy: 88.60%

Model description

  • Architecture: ResNet18 with the final fully-connected layer replaced by a 10-class output layer for CIFAR-10 (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
  • Pretrained Weights: ImageNet-1K
  • Fine-Tuning: The model was fine-tuned on CIFAR-10 images upscaled from their native 32×32 to 128×128 pixels.
  • Data Augmentation: Random horizontal flip and random rotation; inputs are then normalized with mean=0.5 and std=0.5 for each channel.

Intended uses & limitations

  • Intended use: Educational/demo purposes or as a starting point for further fine-tuning on similar image classification tasks.
  • Not intended for: Production-critical tasks without further evaluation. CIFAR-10 is a relatively small, low-resolution (32×32) dataset, so the model may not generalize to non-CIFAR images without additional fine-tuning.

Training procedure

Hyperparameters (approximate):

  • optimizer: Adam
  • learning_rate: 1e-3
  • batch_size: 32
  • num_epochs: 15
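A minimal training loop using the listed hyperparameters (Adam, learning rate 1e-3, batch size 32, 15 epochs) might look like the sketch below. The card does not publish its training script: the cross-entropy loss is an assumption, the per-epoch metrics are chosen to mirror the training log, and `train` is a hypothetical function name.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(model: nn.Module, loader: DataLoader, device: torch.device,
          num_epochs: int = 15, lr: float = 1e-3) -> list[tuple[float, float]]:
    """Hypothetical helper: fine-tune `model`, returning (mean loss, accuracy %) per epoch."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()  # assumed loss for 10-class classification
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(num_epochs):
        model.train()  # enable augmentation-time behavior such as BatchNorm batch statistics
        total_loss, correct, seen = 0.0, 0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size so the epoch mean is per sample.
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
        epoch_loss, epoch_acc = total_loss / seen, 100.0 * correct / seen
        history.append((epoch_loss, epoch_acc))
        print(f"Epoch {epoch + 1}: loss={epoch_loss:.4f} acc={epoch_acc:.2f}%")
    return history
```

For CIFAR-10 this would be called with a loader such as `DataLoader(train_set, batch_size=32, shuffle=True)`, where `train_set` is a hypothetical training dataset that applies the augmentations listed under Model description. Because metrics are accumulated while the weights are still being updated, the reported training accuracy for an epoch is a running figure, not a clean re-evaluation.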

GPU/CPU:

  • The training script uses a single GPU (torch.device("cuda")) when one is available and falls back to the CPU otherwise.
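The GPU-with-CPU-fallback behavior is usually written as a one-line check. The sketch below is an assumption about how the device is selected, not code taken from the training script; the small linear layer is a stand-in for the ResNet18 model.

```python
import torch

# Use a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The model and each input batch must live on the same device before the forward pass.
model = torch.nn.Linear(4, 2).to(device)  # stand-in for the ResNet18 model
batch = torch.randn(8, 4, device=device)
print(model(batch).shape)
```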

Training logs (loss and accuracy on the training set for each epoch):

Epoch  Training Loss  Training Accuracy  Validation Accuracy
1      0.7013         76.52%             -
2      0.4248         85.64%             -
3      0.3185         89.07%             -
4      0.2341         92.06%             -
5      0.1762         93.86%             -
6      0.1302         95.55%             -
7      0.1085         96.31%             -
8      0.0925         96.82%             -
9      0.0765         97.37%             -
10     0.0683         97.68%             -
11     0.0655         97.83%             -
12     0.0548         98.18%             -
13     0.0513         98.27%             -
14     0.0461         98.49%             -
15     0.0470         98.41%             88.60%

Note: Validation accuracy was computed at the end of training (final epoch).
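Validation accuracy here is top-1 accuracy: the percentage of held-out images whose highest-scoring logit matches the true label. A minimal sketch of that computation follows; `evaluate` is a hypothetical helper name, and the card does not show the author's evaluation code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # gradients are not needed for evaluation
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Hypothetical helper: top-1 accuracy (in %) of `model` over every batch in `loader`."""
    model.eval()  # BatchNorm uses its running statistics instead of batch statistics
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # index of the highest logit per image
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
    return 100.0 * correct / seen
```

Run once on the held-out split after the final epoch, a function like this produces a single figure such as the 88.60% reported above.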


Usage

Below is a sample usage snippet in Python. Replace username/model_repo_name with the actual model repo id on Hugging Face.

import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Download the fine-tuned weights (a state dict) from the Hugging Face Hub
ckpt_path = hf_hub_download(repo_id="username/model_repo_name", filename="cnn_model.pth")

# Rebuild the same architecture; weights=None skips downloading the ImageNet weights
# (the older pretrained=False argument is deprecated since torchvision 0.13)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10-class head for CIFAR-10
# weights_only=True (PyTorch 1.13+) restricts unpickling to tensors and plain data
model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
model.eval()  # switch BatchNorm to inference behavior

# Preprocessing: same resize and normalization as in training, without augmentation
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Example inference
image = Image.open("your_image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 128, 128)
with torch.no_grad():
    logits = model(input_tensor)
    predicted_class = logits.argmax(dim=1).item()

print("Predicted class ID:", predicted_class)
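The snippet above prints only a class index. One way to turn the logits into readable labels with confidences is a softmax followed by `torch.topk`, sketched below. The helper name `top_k_predictions` is hypothetical, and the label list assumes the standard torchvision CIFAR-10 class order (the same order listed under Model description).

```python
import torch

# Standard CIFAR-10 label order (0 = airplane, ..., 9 = truck); assumed to match training.
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def top_k_predictions(logits: torch.Tensor, k: int = 3) -> list[tuple[str, float]]:
    """Hypothetical helper: the k most likely (class name, probability) pairs for one image."""
    probs = torch.softmax(logits.squeeze(0), dim=0)  # accepts shape (1, 10) or (10,)
    values, indices = torch.topk(probs, k)
    return [(CIFAR10_CLASSES[i], p) for i, p in zip(indices.tolist(), values.tolist())]


if __name__ == "__main__":
    # Stand-in logits; in practice pass the `logits` from the inference example.
    example_logits = torch.tensor([[0.1, 0.2, 0.0, 3.0, 0.5, 2.0, 0.0, 0.1, 0.0, 0.0]])
    for name, prob in top_k_predictions(example_logits):
        print(f"{name}: {prob:.3f}")
```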

Hugging Face repository: teguhteja/ttm_cnn_model (training dataset: CIFAR-10)