---
license: mit
library_name: timm
tags:
- image-classification
- mobilevit
- timm
- drowsiness-detection
- computer-vision
- pytorch
widget:
- modelId: mosesb/drowsiness-detection-mobileViT-v2
  title: Drowsiness Detection with MobileViT v2
  url: >-
    https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
datasets:
- ismailnasri20/driver-drowsiness-dataset-ddd
- yasharjebraeily/drowsy-detection-dataset
metrics:
- accuracy
- f1
- precision
- recall
base_model:
- apple/mobilevitv2-1.0-imagenet1k-256
---

# MobileViT v2 for Drowsiness Detection

This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutional layers with Vision Transformer blocks and is designed to run efficiently on mobile and edge hardware. The model classifies input images into two categories: `Drowsy` and `Non Drowsy`.

This model was trained in PyTorch using the `timm` library and reaches 98.18% accuracy on an unseen test set drawn from a different dataset than the primary training data, which makes it a promising foundation for driver safety applications.

## Model Details
*   **Architecture:** `mobilevitv2_200`
*   **Fine-tuned on:** A combined dataset for driver drowsiness detection.
*   **Classes:** `Drowsy`, `Non Drowsy`
*   **Frameworks:** PyTorch, timm

## How to Get Started

You can load this model with the `timm` and `torch` libraries. First, download the `best_model.pt` weights file from this repository.

```python
# Install required libraries first (in a notebook, prefix the command with `!`):
#   pip install timm torch torchvision

import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Setup Model and Preprocessing ---
# Define the same transformations used for validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define class names (ensure order matches training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Load the model architecture
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)

# Load the fine-tuned weights
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run Inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image
input_tensor = val_test_transform(image).unsqueeze(0) # Add batch dimension

# Get model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_class_index = torch.topk(probabilities, 1)

class_name = class_names[top_class_index.item()]
confidence = top_prob.item()

print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```

## Training Procedure

The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:
-   **Data Augmentation:** A strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing`.
-   **Transfer Learning:** The model was initialized with ImageNet-pretrained weights, which provides strong general-purpose features and speeds up convergence.
-   **Early Stopping:** To prevent overfitting, training was stopped once validation accuracy had not improved for 30 consecutive epochs (patience = 30).

### Key Hyperparameters
- **Image Size:** 224x224
- **Batch Size:** 64
- **Optimizer:** AdamW (lr=1e-4)
- **Scheduler:** ExponentialLR (gamma=0.90)
- **Loss Function:** CrossEntropyLoss
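The hyperparameters above, together with the early-stopping rule from the training procedure, can be combined into a training loop like the one below. This is a sketch, not the notebook's code: the function name `train_with_early_stopping`, the once-per-epoch `scheduler.step()`, and the in-memory best checkpoint are assumptions.

```python
import copy

import torch
from torch import nn


def train_with_early_stopping(model, train_loader, val_loader,
                              max_epochs=100, patience=30, device="cpu"):
    """Hypothetical helper: AdamW + ExponentialLR + CrossEntropyLoss, stopping once
    validation accuracy has not improved for `patience` consecutive epochs."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # ExponentialLR multiplies the learning rate by gamma on every scheduler.step()
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)

    best_acc, best_state, stale_epochs, history = -1.0, None, 0, []
    for _ in range(max_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        scheduler.step()  # assumed: decay once per epoch

        # Measure validation accuracy
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                correct += (model(images).argmax(dim=1) == labels).sum().item()
                total += labels.numel()
        val_acc = correct / total
        history.append(val_acc)

        if val_acc > best_acc:  # improvement: remember weights, reset the counter
            best_acc, best_state, stale_epochs = val_acc, copy.deepcopy(model.state_dict()), 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break

    model.load_state_dict(best_state)  # restore the best-performing weights
    return best_acc, history
```

With `gamma=0.90` the learning rate after epoch *n* is `1e-4 * 0.9**n`, so it has dropped to roughly 3.5e-5 after 10 epochs and roughly 4.2e-6 after 30. Calling `torch.save(model.state_dict(), "best_model.pt")` after training produces a state dict in the format that the inference snippet above loads.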

![Training Results](training_plot.png)

## Evaluation

The model was evaluated on a completely **unseen test set** (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.

### Key Performance Metrics
| Metric | Value  | Description                                        |
| :----: | :----: | :------------------------------------------------- |
| **Accuracy** | 98.18% | Overall correctness on the test set.           |
| **APCER**    | 3.57%  | Rate of 'Drowsy' drivers missed (False Negatives). |
| **BPCER**    | 0.00%  | Rate of 'Non Drowsy' drivers flagged (False Positives). |
| **ACER**     | 1.78%  | Average of APCER and BPCER.                        |

*APCER (Attack Presentation Classification Error Rate, adapted here from presentation attack detection, with `Drowsy` playing the role of the attack class) is the most critical safety metric, as it measures the failure to detect a drowsy driver.*
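The error rates in the table follow directly from confusion-matrix counts, with `Drowsy` treated as the positive class. A minimal sketch of the computation is shown below; the function name and the counts in the usage line are hypothetical and are not the model's actual test results.

```python
def drowsiness_error_rates(tp, fn, tn, fp):
    """Hypothetical helper. Drowsy is the positive ('attack') class:
    tp = Drowsy predicted Drowsy,        fn = Drowsy predicted Non Drowsy,
    tn = Non Drowsy predicted Non Drowsy, fp = Non Drowsy predicted Drowsy."""
    apcer = fn / (tp + fn)        # share of drowsy drivers that were missed
    bpcer = fp / (tn + fp)        # share of alert drivers wrongly flagged
    acer = (apcer + bpcer) / 2    # average of the two error rates
    accuracy = (tp + tn) / (tp + fn + tn + fp)
    return {"accuracy": accuracy, "apcer": apcer, "bpcer": bpcer, "acer": acer}


# Usage with made-up counts
print(drowsiness_error_rates(tp=90, fn=10, tn=95, fp=5))
```

Because ACER averages the two per-class error rates, it is not affected by class imbalance in the test set the way plain accuracy is.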

![Confusion Matrix](output_confusion_matrix.png)

### Model Explainability (Grad-CAM)
To check that the model focuses on relevant facial features, Grad-CAM was applied. The heatmaps indicate that the model's predictions are driven primarily by the driver's eyes, mouth, and head position, which are key indicators of drowsiness.

![Grad-CAM Visualization](output_grad_cam.jpg)
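Grad-CAM weights a convolutional layer's activation maps by the spatially averaged gradient of the class score, sums them, and keeps the positive part. The following is a minimal hook-based sketch of that technique in plain PyTorch; the function name `grad_cam` is hypothetical, and the card does not say which implementation or target layer was used for the figure above.

```python
import torch
import torch.nn.functional as F


def grad_cam(model, target_layer, input_tensor, class_idx=None):
    """Hypothetical Grad-CAM helper. Returns an (H, W) heatmap scaled to [0, 1]
    and the class index that was explained."""
    store = {}

    def forward_hook(module, inputs, output):
        store["acts"] = output.detach()
        # Capture the gradient flowing back into this layer's output
        output.register_hook(lambda grad: store.__setitem__("grads", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        logits = model(input_tensor)              # (1, num_classes)
        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        model.zero_grad()
        logits[0, class_idx].backward()           # gradient of the chosen class score
    finally:
        handle.remove()

    acts, grads = store["acts"], store["grads"]   # both (1, C, h, w)
    weights = grads.mean(dim=(2, 3), keepdim=True)            # one weight per channel
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))   # (1, 1, h, w)
    cam = F.interpolate(cam, size=input_tensor.shape[-2:],
                        mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()                     # normalize to [0, 1] for overlaying
    return cam, class_idx
```

With the model from "How to Get Started", a call such as `grad_cam(model, model.stages[-1], input_tensor)` would explain the prediction; `model.stages[-1]` is an assumed choice of target layer, so inspect `print(model)` to pick a layer that outputs a 4D feature map.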

## Intended Use and Limitations
This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.

Real-world performance may vary based on:
-   Lighting conditions (especially at night).
-   Camera angles and distance.
-   Occlusions (e.g., sunglasses, hats, hands on face).
-   Individual differences not represented in the training data.

*This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*