---
license: mit
library_name: timm
tags:
- image-classification
- mobilevit
- timm
- drowsiness-detection
- computer-vision
- pytorch
widget:
- src: >-
    https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
  example_title: Drowsiness Detection with MobileViT v2
datasets:
- ismailnasri20/driver-drowsiness-dataset-ddd
- yasharjebraeily/drowsy-detection-dataset
metrics:
- accuracy
- f1
- precision
- recall
base_model:
- apple/mobilevitv2-1.0-imagenet1k-256
---
# MobileViT v2 for Drowsiness Detection
This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutions with Vision Transformer blocks, making it both efficient and accurate. The model classifies input images into two categories: `Drowsy` and `Non Drowsy`.
This model was trained in PyTorch using the `timm` library and demonstrates high performance on an unseen test set, making it a reliable foundation for driver safety applications.
## Model Details
* **Architecture:** `mobilevitv2_200`
* **Fine-tuned on:** A combined dataset for driver drowsiness detection.
* **Classes:** `Drowsy`, `Non Drowsy`
* **Frameworks:** PyTorch, timm
## How to Get Started
You can use this model with the `timm` and `torch` libraries. First, download the `best_model.pt` file from this repository.
```python
# Install the required libraries first:
#   pip install timm torch torchvision

import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Set up the model and preprocessing ---

# Use the same transformations that were applied for validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class names (the order must match training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Build the architecture, then load the fine-tuned weights
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image and add a batch dimension
input_tensor = val_test_transform(image).unsqueeze(0)

# Get the model prediction
with torch.no_grad():
    output = model(input_tensor)

probabilities = torch.nn.functional.softmax(output[0], dim=0)
top_prob, top_class_index = torch.topk(probabilities, 1)
class_name = class_names[top_class_index.item()]
confidence = top_prob.item()
print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```
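The normalization constants are the standard ImageNet statistics, matching the ImageNet pretraining of the `mobilevitv2_200` base model.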
## Training Procedure
The model was fine-tuned on a combined dataset of over 40,000 driver images. The training process involved:
- **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing` (see the sketch after this list).
- **Transfer Learning:** The model was initialized with ImageNet-pretrained weights, enabling robust feature extraction and fast convergence.
- **Early Stopping:** Training was stopped once validation accuracy had not improved for 30 consecutive epochs, to prevent overfitting.
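For reference, a minimal sketch of such an augmentation pipeline in `torchvision`; the crop scale, flip probability, jitter strengths, and erasing probability below are illustrative assumptions, not the exact values from the training notebook:

```python
from torchvision import transforms

# Illustrative training-time augmentation pipeline; the specific
# parameter values here are assumptions, not the original notebook's.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),  # operates on tensors, so it must follow ToTensor
])
```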
### Key Hyperparameters
- **Image Size:** 224x224
- **Batch Size:** 64
- **Optimizer:** AdamW (lr=1e-4)
- **Scheduler:** ExponentialLR (gamma=0.90)
- **Loss Function:** CrossEntropyLoss
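Taken together, the optimization setup corresponds roughly to the following sketch; `train_loader` and the surrounding epoch loop are assumed for illustration:

```python
import torch
import timm

model = timm.create_model('mobilevitv2_200', pretrained=True, num_classes=2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
criterion = torch.nn.CrossEntropyLoss()

# One simplified training epoch; train_loader is assumed to be a
# DataLoader with batch_size=64 yielding (images, labels) batches.
def train_one_epoch(model, train_loader, device):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # exponential LR decay, applied once per epoch
```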
![Training Results](training_plot.png)
## Evaluation
The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
### Key Performance Metrics
| Metric | Value | Description |
| :----: | :----: | :------------------------------------------------- |
| **Accuracy** | 98.18% | Overall correctness on the test set. |
| **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
| **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
| **ACER** | 1.78% | Average of APCER and BPCER. |
*APCER (Attack Presentation Classification Error Rate, a term adapted here from biometric presentation attack detection with 'Drowsy' as the positive class) is the most safety-critical metric, as it measures the failure to detect a drowsy driver.*
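For concreteness, these rates follow directly from the binary confusion matrix; a minimal sketch, treating `Drowsy` as the positive class as in the table above:

```python
# Error rates from the binary confusion matrix, with 'Drowsy' as the
# positive class: fn = drowsy drivers predicted 'Non Drowsy',
# fp = non-drowsy drivers predicted 'Drowsy'.
def error_rates(tp: int, fn: int, tn: int, fp: int):
    apcer = fn / (tp + fn)      # missed drowsy drivers (false-negative rate)
    bpcer = fp / (tn + fp)      # falsely flagged alert drivers (false-positive rate)
    acer = (apcer + bpcer) / 2  # average of the two
    return apcer, bpcer, acer
```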
![Confusion Matrix](output_confusion_matrix.png)
### Model Explainability (Grad-CAM)
To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.
![Grad-CAM Visualization](output_grad_cam.jpg)
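To reproduce a heatmap like the one above, here is a minimal sketch using the third-party `grad-cam` package (`pip install grad-cam`); both the package and the choice of `model.stages[-1]` as the target layer are assumptions, not part of the original training notebook:

```python
# pip install grad-cam
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# model, val_test_transform, and image come from the snippet in
# "How to Get Started" above.
input_tensor = val_test_transform(image).unsqueeze(0)

# model.stages[-1] is a plausible final feature stage for timm's
# mobilevitv2 models; verify the attribute on your installed version.
cam = GradCAM(model=model, target_layers=[model.stages[-1]])
heatmap = cam(input_tensor=input_tensor,
              targets=[ClassifierOutputTarget(0)])[0]  # class 0 = 'Drowsy'

# Overlay the heatmap on the resized input image and save it
rgb = (np.array(image.resize((224, 224))) / 255.0).astype(np.float32)
overlay = show_cam_on_image(rgb, heatmap, use_rgb=True)
Image.fromarray(overlay).save('grad_cam_overlay.jpg')
```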
## Intended Use and Limitations
This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.
Real-world performance may vary based on:
- Lighting conditions (especially at night).
- Camera angles and distance.
- Occlusions (e.g., sunglasses, hats, hands on face).
- Individual differences not represented in the training data.
*This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*