MobileViT v2 for Drowsiness Detection
This repository contains a MobileViT v2 classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutions with transformer blocks based on separable self-attention, which keeps it efficient on resource-constrained hardware while remaining accurate. The model classifies each input image into one of two categories: `Drowsy` or `Non Drowsy`.
The model was trained in PyTorch using the `timm` library and reaches 98.18% accuracy on an unseen test set (see Evaluation), which makes it a promising starting point for driver-safety applications.
Model Details
- Architecture: `mobilevitv2_200`
- Fine-tuned on: a combined driver-drowsiness dataset of over 40,000 images.
- Classes: `Drowsy` (index 0), `Non Drowsy` (index 1)
- Frameworks: PyTorch, timm
How to Get Started
You can use this model with the `timm` and `torch` libraries. First, download the `best_model.pt` weights file from this repository.
```bash
# Install required libraries (in a notebook, prefix the command with "!")
pip install timm torch torchvision
```

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Set up model and preprocessing ---
# Use the same transformations as during validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class names (order must match training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Build the model architecture with a 2-class head
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)

# Load the fine-tuned weights (mapped to CPU so this also works without a GPU)
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image and add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
input_tensor = val_test_transform(image).unsqueeze(0)

# Get the model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

top_prob, top_class_index = torch.topk(probabilities, 1)
class_name = class_names[top_class_index.item()]
confidence = top_prob.item()
print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```
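To classify several images at once (for example, frames sampled from a video), the single-image snippet can be generalized to one batched forward pass. The following is a minimal sketch that is not part of the original card. The helper name `predict_batch` is hypothetical. It assumes the `model`, `val_test_transform`, and `class_names` objects from the inference example above.

```python
from typing import Callable, List, Sequence, Tuple

import torch
import torch.nn as nn


def predict_batch(
    model: nn.Module,
    images: Sequence,
    transform: Callable,
    class_names: List[str],
    device: str = "cpu",
) -> List[Tuple[str, float]]:
    """Hypothetical helper: classify a list of images in a single forward pass.

    Returns one (class_name, confidence) pair per input image.
    """
    model = model.to(device).eval()
    # Preprocess each image, then stack into an (N, C, H, W) batch.
    batch = torch.stack([transform(img) for img in images]).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)  # (N, num_classes)
    confidences, indices = probs.max(dim=1)
    return [(class_names[i], c) for i, c in zip(indices.tolist(), confidences.tolist())]


# Example usage (file names are hypothetical):
# images = [Image.open(p).convert("RGB") for p in ["frame1.jpg", "frame2.jpg"]]
# print(predict_batch(model, images, val_test_transform, class_names))
```

Batching amortizes per-call overhead. On a GPU, pass `device="cuda"`.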
Training Procedure
The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:
- Data Augmentation: A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing`.
- Transfer Learning: The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
- Early Stopping: To prevent overfitting, training was stopped once validation accuracy had not improved for 30 consecutive epochs.
Key Hyperparameters
- Image Size: 224x224
- Batch Size: 64
- Optimizer: AdamW (lr=1e-4)
- Scheduler: ExponentialLR (gamma=0.90)
- Loss Function: CrossEntropyLoss
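The hyperparameters above map directly onto standard PyTorch objects. The following is a minimal sketch with two assumptions. First, `ExponentialLR` is assumed to be stepped once per epoch. Second, the AdamW weight decay is assumed to stay at the PyTorch default, since the card does not state it. The helper `build_training_components` and the `EarlyStopping` class are hypothetical names used for illustration. They are not taken from the training notebook.

```python
import torch
import torch.nn as nn


def build_training_components(model: nn.Module):
    """Hypothetical helper wiring up the listed loss, optimizer, and scheduler."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # weight_decay left at default (assumption)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)  # assumed per-epoch step
    return criterion, optimizer, scheduler


class EarlyStopping:
    """Signal a stop when validation accuracy has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 30):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def step(self, val_accuracy: float) -> bool:
        """Record one epoch's validation accuracy; return True if training should stop."""
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

# Sketch of the epoch loop (train_one_epoch and evaluate are hypothetical):
# for epoch in range(max_epochs):
#     train_one_epoch(model, criterion, optimizer)
#     val_acc = evaluate(model)
#     scheduler.step()                 # lr <- lr * 0.90
#     if stopper.step(val_acc):
#         break
```

With `gamma=0.90`, the learning rate after k epochs is 1e-4 × 0.9^k.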
Evaluation
The model was evaluated on a completely unseen test set (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
Key Performance Metrics
Metric | Value | Description |
---|---|---|
Accuracy | 98.18% | Overall correctness on the test set. |
APCER | 3.57% | Fraction of `Drowsy` images predicted as `Non Drowsy` (missed drowsy drivers, false negatives). |
BPCER | 0.00% | Fraction of `Non Drowsy` images predicted as `Drowsy` (false alarms, false positives). |
ACER | 1.78% | Average of APCER and BPCER. |
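APCER and BPCER are borrowed from presentation attack detection, where they stand for Attack Presentation and Bona fide Presentation Classification Error Rate. Here, `Drowsy` plays the role of the attack class. APCER is the most safety-critical metric, because it measures how often a drowsy driver goes undetected.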
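The metrics in the table can be computed from true and predicted labels as follows. This is a minimal sketch that treats `Drowsy` (index 0) as the attack class. The function name `drowsiness_metrics` is hypothetical.

```python
from typing import Dict, Sequence

DROWSY, NON_DROWSY = 0, 1  # class indices used in the inference example


def drowsiness_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Accuracy, APCER, BPCER and ACER, treating 'Drowsy' as the attack class."""
    pairs = list(zip(y_true, y_pred))
    n_drowsy = sum(1 for t, _ in pairs if t == DROWSY)
    n_non = sum(1 for t, _ in pairs if t == NON_DROWSY)
    missed = sum(1 for t, p in pairs if t == DROWSY and p != DROWSY)             # false negatives
    false_alarms = sum(1 for t, p in pairs if t == NON_DROWSY and p == DROWSY)  # false positives
    # A rate is undefined (NaN) if its class is absent from the test set.
    apcer = missed / n_drowsy if n_drowsy else float("nan")
    bpcer = false_alarms / n_non if n_non else float("nan")
    return {
        "accuracy": sum(t == p for t, p in pairs) / len(pairs),
        "APCER": apcer,
        "BPCER": bpcer,
        "ACER": (apcer + bpcer) / 2,
    }
```

As a consistency check on the reported figures: (3.57% + 0.00%) / 2 ≈ 1.78%, which matches the ACER row.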
Model Explainability (Grad-CAM)
To check that the model focuses on relevant facial features, Grad-CAM heatmaps were generated. They indicate that the model's predictions rely primarily on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.
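The card does not say which Grad-CAM implementation or target layer was used. The sketch below is illustrative only. It implements the standard Grad-CAM recipe with plain PyTorch hooks:

1. Weight each channel of a target layer's activations by the spatially averaged gradient of the class score.
2. Sum the weighted channels and apply a ReLU.
3. Upsample the result to the input size.

For the timm model, `model.stages[-1]` (the last feature stage) is assumed to be a reasonable target layer. Confirm this by inspecting `print(model)`. The function name `grad_cam` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model: nn.Module, target_layer: nn.Module,
             x: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Return an (H, W) Grad-CAM heatmap in [0, 1] for one image x of shape (1, C, H, W).

    Must be called with gradients enabled (not inside torch.no_grad()).
    """
    store = {}

    def forward_hook(module, inputs, output):
        store["activations"] = output
        # Capture the gradient of the class score with respect to this layer's output.
        output.register_hook(lambda grad: store.__setitem__("gradients", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        model.zero_grad()
        logits = model(x)
        logits[0, class_idx].backward()
    finally:
        handle.remove()  # always remove the hook, even if the pass fails

    acts, grads = store["activations"], store["gradients"]   # (1, K, h, w)
    weights = grads.mean(dim=(2, 3), keepdim=True)            # per-channel importance
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))   # (1, 1, h, w)
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam[0, 0]
    cam = cam - cam.min()
    return (cam / (cam.max() + 1e-8)).detach()

# Example with the inference objects above (target layer is an assumption):
# heatmap = grad_cam(model, model.stages[-1], input_tensor, class_idx=0)  # 0 = Drowsy
```

The resulting heatmap can be overlaid on the resized 224x224 input to see which regions drove the prediction.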
Intended Use and Limitations
This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.
Real-world performance may vary based on:
- Lighting conditions (especially at night).
- Camera angles and distance.
- Occlusions (e.g., sunglasses, hats, hands on face).
- Individual differences not represented in the training data.
This model card is based on the training notebook `MobileViT_Drowsiness.ipynb`.
Model tree for mosesb/drowsiness-detection-mobileViT-v2
- Base model: apple/mobilevitv2-1.0-imagenet1k-256