---
license: mit
library_name: timm
tags:
- image-classification
- mobilevit
- timm
- drowsiness-detection
- computer-vision
- pytorch
widget:
- src: >-
    https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
  example_title: Drowsiness Detection with MobileViT v2
datasets:
- ismailnasri20/driver-drowsiness-dataset-ddd
- yasharjebraeily/drowsy-detection-dataset
metrics:
- accuracy
- f1
- precision
- recall
base_model:
- apple/mobilevitv2-1.0-imagenet1k-256
---
# MobileViT v2 for Drowsiness Detection
This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutions with Vision Transformers, making it both efficient and accurate. The model classifies input images into two categories: `Drowsy` and `Non Drowsy`.
This model was trained in PyTorch using the `timm` library and reaches 98.18% accuracy on an unseen test set (see Evaluation below), making it a solid foundation for driver safety applications.
## Model Details
* **Architecture:** `mobilevitv2_200`
* **Fine-tuned on:** A combination of two public driver-drowsiness datasets (listed in the metadata above).
* **Classes:** `Drowsy`, `Non Drowsy`
* **Frameworks:** PyTorch, timm
## How to Get Started
You can use this model with the `timm` and `torch` libraries (install them with `pip install timm torch torchvision`). First, download the `best_model.pt` checkpoint from this repository.
```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Set up the model and preprocessing ---
# Use the same transformations as validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class names must match the label order used during training (Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Recreate the architecture, then load the fine-tuned weights
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image and add a batch dimension
input_tensor = val_test_transform(image).unsqueeze(0)

# Get the model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

top_prob, top_class_index = torch.topk(probabilities, 1)
class_name = class_names[top_class_index.item()]
confidence = top_prob.item()
print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```
## Training Procedure
The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:
- **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing` (sketched just after this list).
- **Transfer Learning:** The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
- **Early Stopping:** Training was halted after 30 epochs of no improvement in validation accuracy to prevent overfitting.
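As a concrete illustration of the augmentation pipeline named above, here is a minimal `torchvision` sketch. The specific magnitudes and probabilities are illustrative assumptions, not the confirmed training configuration.
```python
from torchvision import transforms

# Sketch of the training augmentation pipeline; the magnitudes and
# probabilities below are assumptions, not the notebook's exact values.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # assumed crop scale
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # assumed strengths
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),  # assumed probability; acts on tensors, so placed after ToTensor
])
```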
### Key Hyperparameters
- **Image Size:** 224x224
- **Batch Size:** 64
- **Optimizer:** AdamW (lr=1e-4)
- **Scheduler:** ExponentialLR (gamma=0.90)
- **Loss Function:** CrossEntropyLoss
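Taken together, these map onto a standard PyTorch training loop. The sketch below is a minimal reconstruction under stated assumptions: `train_loader` and `val_loader` are hypothetical DataLoaders over the combined dataset, and the 100-epoch budget is illustrative; only the optimizer, scheduler, loss function, and 30-epoch early-stopping patience come from this card.
```python
import torch

# `model` is the timm MobileViT v2 from the inference example above;
# `train_loader`/`val_loader` are assumed DataLoaders (batch size 64).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
criterion = torch.nn.CrossEntropyLoss()

best_val_acc, epochs_without_improvement = 0.0, 0
for epoch in range(100):  # assumed epoch budget
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # decay the learning rate once per epoch

    # Compute validation accuracy for early stopping
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total

    # Early stopping: halt after 30 epochs without improvement
    if val_acc > best_val_acc:
        best_val_acc, epochs_without_improvement = val_acc, 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= 30:
            break
```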

## Evaluation
The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
### Key Performance Metrics
| Metric | Value | Description |
| :----: | :----: | :------------------------------------------------- |
| **Accuracy** | 98.18% | Overall correctness on the test set. |
| **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
| **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
| **ACER** | 1.78% | Average of APCER and BPCER. |
*APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver.*
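These rates are simple ratios over the test-set confusion matrix, with `Drowsy` as the positive class. A minimal sketch of the computation (the count variables are hypothetical):
```python
# Error rates from hypothetical confusion-matrix counts,
# with 'Drowsy' as the positive class.
def error_rates(tp, fn, tn, fp):
    apcer = fn / (tp + fn)      # drowsy drivers missed (false-negative rate)
    bpcer = fp / (tn + fp)      # alert drivers flagged (false-positive rate)
    acer = (apcer + bpcer) / 2  # average of the two
    return apcer, bpcer, acer

# E.g. APCER = 3.57% and BPCER = 0.00% give ACER = (3.57 + 0.00) / 2 ≈ 1.78%,
# matching the table above.
```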

### Model Explainability (Grad-CAM)
To verify that the model focuses on relevant facial features, Grad-CAM was applied to its predictions. The resulting heatmaps confirm that the model relies primarily on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.
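A heatmap of this kind can be generated with the `pytorch-grad-cam` package (`pip install grad-cam`). The sketch below is an assumption about the procedure, not the notebook's confirmed setup; in particular, choosing `model.stages[-1]` as the target layer is a guess at timm's MobileViT v2 internals.
```python
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# `model`, `image`, and `input_tensor` come from the inference example above.
# The target layer is an assumption about timm's MobileViT v2 layout.
target_layers = [model.stages[-1]]
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(0)])  # 0 = 'Drowsy'
    rgb = np.array(image.resize((224, 224)), dtype=np.float32) / 255.0
    heatmap = show_cam_on_image(rgb, grayscale_cam[0], use_rgb=True)  # uint8 overlay
```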

## Intended Use and Limitations
This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.
Real-world performance may vary based on:
- Lighting conditions (especially at night).
- Camera angles and distance.
- Occlusions (e.g., sunglasses, hats, hands on face).
- Individual differences not represented in the training data.
*This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*