|
--- |
|
license: mit |
|
library_name: timm |
|
tags: |
|
- image-classification |
|
- mobilevit |
|
- timm |
|
- drowsiness-detection |
|
- computer-vision |
|
- pytorch |
|
widget:
- src: >-
    https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
  example_title: Drowsiness Detection with MobileViT v2
|
datasets: |
|
- ismailnasri20/driver-drowsiness-dataset-ddd |
|
- yasharjebraeily/drowsy-detection-dataset |
|
metrics: |
|
- accuracy |
|
- f1 |
|
- precision |
|
- recall |
|
base_model: |
|
- apple/mobilevitv2-1.0-imagenet1k-256 |
|
--- |
|
|
|
# MobileViT v2 for Drowsiness Detection |
|
|
|
This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutions with Vision Transformer blocks, giving a strong accuracy/efficiency trade-off. The model classifies input images into two categories: `Drowsy` and `Non Drowsy`.
|
|
|
This model was trained in PyTorch using the `timm` library and reaches 98.18% accuracy on an unseen test set (see Evaluation below), making it a solid foundation for driver-safety applications.
|
|
|
## Model Details |
|
* **Architecture:** `mobilevitv2_200` (see the snippet after this list)

* **Fine-tuned on:** A combination of two driver drowsiness datasets: `ismailnasri20/driver-drowsiness-dataset-ddd` and `yasharjebraeily/drowsy-detection-dataset`.
|
* **Classes:** `Drowsy`, `Non Drowsy` |
|
* **Frameworks:** PyTorch, timm |
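
As a quick sanity check of the backbone's size, the architecture can be instantiated from `timm` and its parameters counted (a minimal sketch; no fine-tuned weights are loaded here):

```python
import timm

# Bare architecture with a 2-class head; weights are randomly initialized here
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
n_params = sum(p.numel() for p in model.parameters())
print(f"mobilevitv2_200: {n_params / 1e6:.1f}M parameters")
```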
|
|
|
## How to Get Started |
|
|
|
You can easily use this model with the `timm` and `torch` libraries. First, ensure you have the `best_model.pt` file from this repository. |
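
If you prefer to fetch the checkpoint programmatically, `huggingface_hub` can download it. This is a sketch: the `repo_id` below is assumed from this card's name and may need adjusting.

```python
from huggingface_hub import hf_hub_download

# repo_id is an assumption based on this model card; adjust if the repo differs
model_path = hf_hub_download(
    repo_id="mosesb/drowsiness-detection-mobileViT-v2",
    filename="best_model.pt",
)
```

The returned `model_path` can then be passed to `torch.load` in the example below.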
|
|
|
```python
# Install required libraries first:
#   pip install timm torch torchvision

import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Setup Model and Preprocessing ---
# Define the same transformations used for validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define class names (ensure order matches training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Load the model architecture
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)

# Load the fine-tuned weights
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run Inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image
input_tensor = val_test_transform(image).unsqueeze(0)  # Add batch dimension

# Get model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_class_index = torch.topk(probabilities, 1)

class_name = class_names[top_class_index.item()]
confidence = top_prob.item()

print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```
|
|
|
## Training Procedure |
|
|
|
The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved: |
|
- **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing` (sketched after this list).

- **Transfer Learning:** The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.

- **Early Stopping:** Training stopped once validation accuracy had not improved for 30 consecutive epochs, to prevent overfitting.
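
A sketch of what such an augmentation pipeline can look like in `torchvision`; the crop scale, jitter strengths, and erasing probability below are illustrative assumptions, not the exact training values:

```python
from torchvision import transforms

# Illustrative training-time augmentation pipeline; parameter values are
# assumptions, not the exact ones used to train this model.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),  # operates on tensors, so it comes after ToTensor()
])
```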
|
|
|
### Key Hyperparameters |
|
- **Image Size:** 224x224 |
|
- **Batch Size:** 64 |
|
- **Optimizer:** AdamW (lr=1e-4) |
|
- **Scheduler:** ExponentialLR (gamma=0.90) |
|
- **Loss Function:** CrossEntropyLoss (see the setup sketch after this list)
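
A minimal sketch of how these pieces fit together, following the hyperparameters above (data loading and the epoch loop are elided; this is not the original training script):

```python
import timm
import torch

# Initialize from ImageNet-pretrained weights (transfer learning)
model = timm.create_model('mobilevitv2_200', pretrained=True, num_classes=2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
criterion = torch.nn.CrossEntropyLoss()

def train_step(images, labels):
    """One optimization step; batches come from a DataLoader with batch_size=64."""
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()

# With ExponentialLR, scheduler.step() is called once per epoch.
```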
|
|
|
 |
|
|
|
## Evaluation |
|
|
|
The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities. |
|
|
|
### Key Performance Metrics |
|
| Metric | Value | Description |
| :----: | :----: | :------------------------------------------------- |
| **Accuracy** | 98.18% | Overall correctness on the test set. |
| **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
| **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
| **ACER** | 1.78% | Average of APCER and BPCER. |
|
|
|
*APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver.* |
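
With `Drowsy` treated as the positive class, these metrics are simple ratios over the confusion-matrix counts. A small sketch with illustrative counts (chosen to match the reported rates, not the actual test-set numbers):

```python
# Confusion-matrix counts with 'Drowsy' as the positive class
# (illustrative values; substitute the counts from your own test run)
tp, fn = 540, 20   # drowsy correctly caught / drowsy missed
tn, fp = 560, 0    # non-drowsy correctly passed / non-drowsy flagged

apcer = fn / (fn + tp)      # drowsy drivers the model missed
bpcer = fp / (fp + tn)      # alert drivers wrongly flagged as drowsy
acer = (apcer + bpcer) / 2  # balanced average of the two error rates

print(f"APCER={apcer:.2%}  BPCER={bpcer:.2%}  ACER={acer:.2%}")
```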
|
|
|
 |
|
|
|
### Model Explainability (Grad-CAM) |
|
To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness. |
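
A minimal sketch of producing such heatmaps with the third-party `pytorch-grad-cam` package (`pip install grad-cam`), reusing `model`, `input_tensor`, and `image` from the inference example above. The target layer is an assumption about the `timm` MobileViT v2 module layout and may need adjusting:

```python
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Assumption: hook the last stage of the timm backbone
target_layers = [model.stages[-1]]

cam = GradCAM(model=model, target_layers=target_layers)
heatmap = cam(input_tensor=input_tensor,
              targets=[ClassifierOutputTarget(0)])[0]  # class 0 = 'Drowsy'

# Overlay the heatmap on the resized input image (scaled to [0, 1])
rgb = np.asarray(image.resize((224, 224)), dtype=np.float32) / 255.0
overlay = show_cam_on_image(rgb, heatmap, use_rgb=True)  # uint8 H x W x 3 image
```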
|
|
|
 |
|
|
|
## Intended Use and Limitations |
|
This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing. |
|
|
|
Real-world performance may vary based on: |
|
- Lighting conditions (especially at night). |
|
- Camera angles and distance. |
|
- Occlusions (e.g., sunglasses, hats, hands on face). |
|
- Individual differences not represented in the training data. |
|
|
|
*This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).* |