vit-base-oxford-iiit-pets
A Vision Transformer (ViT-Base) model fine-tuned on the Oxford-IIIT Pet Dataset for fine-grained image classification of 37 cat and dog breeds.
Evaluation Results
- Final Validation Loss: 0.2496
- Final Accuracy: 93.10%
Training Loss | Epoch | Step | Validation Loss | Accuracy |
---|---|---|---|---|
1.7389 | 1.0 | 185 | 0.4195 | 0.8999 |
0.3005 | 2.0 | 370 | 0.2997 | 0.9161 |
0.2178 | 3.0 | 555 | 0.2647 | 0.9256 |
0.1928 | 4.0 | 740 | 0.2533 | 0.9323 |
0.1734 | 5.0 | 925 | 0.2496 | 0.9310 |
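The per-epoch figures can be checked with a few lines of Python. The rows are copied from the table above; the sketch only shows that the lowest validation loss (epoch 5) and the highest accuracy (epoch 4) fall on different epochs, so the "final" numbers reported at the top are those of the last epoch rather than the best accuracy.

```python
# Per-epoch results copied from the evaluation table.
# Each row: (epoch, step, training_loss, validation_loss, accuracy)
HISTORY = [
    (1, 185, 1.7389, 0.4195, 0.8999),
    (2, 370, 0.3005, 0.2997, 0.9161),
    (3, 555, 0.2178, 0.2647, 0.9256),
    (4, 740, 0.1928, 0.2533, 0.9323),
    (5, 925, 0.1734, 0.2496, 0.9310),
]


def best_epoch(history, metric_index, higher_is_better):
    """Return the epoch number whose metric at `metric_index` is best."""
    pick = max if higher_is_better else min
    row = pick(history, key=lambda r: r[metric_index])
    return row[0]


if __name__ == "__main__":
    print("Lowest validation loss at epoch", best_epoch(HISTORY, 3, higher_is_better=False))
    print("Highest accuracy at epoch", best_epoch(HISTORY, 4, higher_is_better=True))
```

With the Transformers `Trainer`, setting `load_best_model_at_end=True` together with `metric_for_best_model="accuracy"` keeps the best checkpoint instead of the last one; the card does not say whether that option was used.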
Model Description
- Architecture: Vision Transformer (ViT-Base)
- Patch Size: 16x16
- Input Resolution: 224x224
- Classes: 37 (cat and dog breeds)
- Base Model: google/vit-base-patch16-224
- Trained on: Oxford-IIIT Pet Dataset
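The patch size and input resolution determine how many tokens the transformer encoder sees. A minimal sketch of that arithmetic for a 224x224 RGB input split into 16x16 patches; the extra [CLS] token is part of the standard ViT design rather than something this card states.

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens fed to the encoder: one per patch plus the [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1                       # +1 for the [CLS] token


def flattened_patch_dim(patch_size: int = 16, channels: int = 3) -> int:
    """Length of one flattened RGB patch before the linear patch projection."""
    return patch_size * patch_size * channels    # 16 * 16 * 3 = 768


if __name__ == "__main__":
    print("Sequence length:", vit_token_count())
    print("Flattened patch size:", flattened_patch_dim())
```

For ViT-Base/16 the flattened patch length (768) happens to equal the model's hidden size of 768, so the patch projection is a square 768x768 linear map.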
Intended Use
- Pet breed classification (cats & dogs)
- Transfer learning on similar classification tasks
- Educational and benchmarking use for ViT models
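For transfer learning to a different label set, a common Transformers pattern is to reload the checkpoint with a new `num_labels` and `ignore_mismatched_sizes=True`, which discards the 37-way classification head and initializes a fresh one. A sketch under that assumption; the class names in the example are hypothetical, and the `transformers` import is deferred so the label-mapping helper runs on its own.

```python
def build_label_maps(class_names):
    """Build the id2label / label2id dictionaries stored in a model config."""
    id2label = {i: name for i, name in enumerate(class_names)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id


def load_for_new_task(class_names, checkpoint="rakib730/vit-base-oxford-iiit-pets"):
    """Reload the checkpoint with a newly initialized head sized for `class_names`."""
    from transformers import AutoModelForImageClassification  # deferred import

    id2label, label2id = build_label_maps(class_names)
    return AutoModelForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(class_names),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,  # drop the 37-class head weights
    )


if __name__ == "__main__":
    # Hypothetical three-class task, for illustration only.
    names = ["rabbit", "hamster", "guinea_pig"]
    id2label, label2id = build_label_maps(names)
    print(id2label)
```

Passing `checkpoint="google/vit-base-patch16-224"` instead starts from the ImageNet-trained base model rather than from the pet-breed weights.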
Limitations
- Trained only on pet images, so it may not generalize to wild animals
- Not optimized for real-time or mobile deployment
- Sensitive to poor image quality or unusual aspect ratios
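If the checkpoint's image processor resizes inputs straight to 224x224 (the ViT default), non-square photos are stretched before classification. One possible mitigation, an assumption that this card has not evaluated, is to center-crop each image to a square first. A minimal sketch that computes the crop box and applies it with PIL's `Image.crop`:

```python
def center_square_box(width: int, height: int):
    """Return a (left, upper, right, lower) box for the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def center_crop_square(image):
    """Crop a PIL image to its centered square so resizing does not distort it."""
    return image.crop(center_square_box(*image.size))  # image.size is (width, height)


if __name__ == "__main__":
    print(center_square_box(640, 480))
```

The trade-off is that cropping can cut off part of the animal when it sits near the edge of a wide or tall photo, so whether it helps depends on the images.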
Dataset
- Name: Oxford-IIIT Pet Dataset
- Images: 7,349 total
- Classes: 37 pet breeds
- Split: Random train-validation (80/20)
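The card does not say how the random split was produced. The sketch below is a seeded shuffle-and-slice that reuses the seed of 42 listed under Training Details; it is an assumption and will not reproduce the author's exact split. The Hugging Face Datasets library offers its own equivalent, `Dataset.train_test_split(test_size=0.2, seed=42)`.

```python
import random


def random_split(n_items: int, val_fraction: float = 0.2, seed: int = 42):
    """Shuffle indices 0..n_items-1 with a fixed seed and split them train/validation."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # private RNG, so the global state is untouched
    n_val = round(n_items * val_fraction)
    return indices[n_val:], indices[:n_val]


if __name__ == "__main__":
    train_idx, val_idx = random_split(7349)
    print(len(train_idx), len(val_idx))
```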
Training Details
Hyperparameters
- learning_rate: 3e-4
- train_batch_size: 32
- eval_batch_size: 16
- num_train_epochs: 5
- optimizer: AdamW (betas=(0.9, 0.999), epsilon=1e-08)
- lr_scheduler_type: linear
- seed: 42
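The hyperparameters listed above map onto keyword arguments of `transformers.TrainingArguments`, and the "linear" scheduler decays the learning rate to zero over training. A sketch under stated assumptions: a single device with no gradient accumulation (so `train_batch_size` equals the per-device size), zero warmup steps since none are listed, and 925 total steps taken from the last row of the evaluation table.

```python
# Hyperparameters from the card as TrainingArguments keyword arguments.
# Assumes one device and no gradient accumulation.
TRAINING_KWARGS = {
    "learning_rate": 3e-4,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 16,
    "num_train_epochs": 5,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "lr_scheduler_type": "linear",
    "seed": 42,
}


def linear_lr(step: int, total_steps: int, base_lr: float = 3e-4, warmup_steps: int = 0) -> float:
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    total = 925  # final step in the evaluation table
    for step in (0, 185, 462, 925):
        print(step, f"{linear_lr(step, total):.2e}")
```

With the library installed, `TrainingArguments(output_dir="out", **TRAINING_KWARGS)` would build the corresponding configuration; `output_dir="out"` is a placeholder.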
Framework Versions
- Transformers: 4.51.3
- PyTorch: 2.6.0+cu124
- Datasets: 3.6.0
- Tokenizers: 0.21.1
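To see whether a local environment matches these versions, a small standard-library sketch using `importlib.metadata`. It only reports differences rather than failing, since other releases may well load the checkpoint too; the PyPI distribution names are assumed to be `transformers`, `torch`, `datasets`, and `tokenizers`.

```python
from importlib import metadata

# Versions listed in the card, keyed by PyPI distribution name.
PINNED = {
    "transformers": "4.51.3",
    "torch": "2.6.0+cu124",
    "datasets": "3.6.0",
    "tokenizers": "0.21.1",
}


def compare_versions(pinned=PINNED):
    """Map each package to (installed_version_or_None, pinned_version)."""
    report = {}
    for package, wanted in pinned.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed
        report[package] = (installed, wanted)
    return report


if __name__ == "__main__":
    for package, (installed, wanted) in compare_versions().items():
        status = "ok" if installed == wanted else "differs"
        print(f"{package}: installed={installed} card={wanted} ({status})")
```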
How to Use
```python
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch

# Load the fine-tuned model and its image processor
model = AutoModelForImageClassification.from_pretrained("rakib730/vit-base-oxford-iiit-pets")
processor = AutoImageProcessor.from_pretrained("rakib730/vit-base-oxford-iiit-pets")

# Load an image; convert to RGB so grayscale or RGBA files have 3 channels
image = Image.open("your_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Predict without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
predicted_label = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_label])
```