MobileViT v2 for Drowsiness Detection

This repository contains a MobileViT v2 classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a lightweight hybrid architecture that combines convolutions with Vision Transformer blocks, making it both efficient and accurate. The model classifies input images into two categories: Drowsy and Non Drowsy.

This model was trained in PyTorch using the timm library and achieves high accuracy on an unseen test set (see Evaluation below), making it a solid starting point for driver-safety applications.

Model Details

  • Architecture: mobilevitv2_200
  • Fine-tuned on: A combined dataset for driver drowsiness detection.
  • Classes: Drowsy, Non Drowsy
  • Frameworks: PyTorch, timm

How to Get Started

You can use this model with the torch and timm libraries. First, download the best_model.pt weights file from this repository.

# Install required libraries
pip install timm torch torchvision

import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Setup Model and Preprocessing ---
# Define the same transformations used for validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define class names (ensure order matches training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Load the model architecture
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)

# Load the fine-tuned weights
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run Inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image
input_tensor = val_test_transform(image).unsqueeze(0) # Add batch dimension

# Get model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_class_index = torch.topk(probabilities, 1)

class_name = class_names[top_class_index.item()]
confidence = top_prob.item()

print(f"Prediction: {class_name} with confidence {confidence:.4f}")

Training Procedure

The model was fine-tuned on a large dataset of over 40,000 driver images. The training process involved:

  • Data Augmentation: A strong augmentation pipeline was used for training, including RandomResizedCrop, RandomHorizontalFlip, ColorJitter, and RandomErasing (a sketch follows this list).
  • Transfer Learning: The model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
  • Early Stopping: Training was halted after 30 epochs of no improvement in validation accuracy to prevent overfitting.
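
The exact augmentation settings are not listed in this card; the following is a minimal sketch of a pipeline built from the transforms named above, where the specific parameter values (crop scale, jitter strength, erasing probability) are assumptions rather than the exact configuration used for this checkpoint.

# Training-time augmentation sketch (parameter values are assumptions)
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),       # random crop, resized to 224x224
    transforms.RandomHorizontalFlip(p=0.5),                     # mirror left/right
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # lighting variation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),                           # random occlusion (applied to tensors)
])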

Key Hyperparameters

  • Image Size: 224x224
  • Batch Size: 64
  • Optimizer: AdamW (lr=1e-4)
  • Scheduler: ExponentialLR (gamma=0.90)
  • Loss Function: CrossEntropyLoss
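
Taken together, these hyperparameters correspond roughly to the setup below. This is a sketch for illustration only; the dataloader and training-loop details are assumptions.

# Optimizer, scheduler, and loss matching the hyperparameters listed above
import torch
import timm

model = timm.create_model('mobilevitv2_200', pretrained=True, num_classes=2)  # ImageNet-pretrained start

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
criterion = torch.nn.CrossEntropyLoss()

def train_one_epoch(model, loader, device='cpu'):
    # One simplified training epoch over batches of size 64
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # decay the learning rate once per epoch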

Training Results

Evaluation

The model was evaluated on a completely unseen test set (from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.
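
As a hedged illustration of how such an evaluation can be run, the loop below collects predictions over a held-out test folder. The test_data/ directory (one subfolder per class, matching the Drowsy/Non Drowsy order above) is an assumed layout, and val_test_transform is the preprocessing defined in the usage example.

# Collect predictions and ground-truth labels on the unseen test set
import torch
from torch.utils.data import DataLoader
from torchvision import datasets

test_dataset = datasets.ImageFolder('test_data', transform=val_test_transform)  # hypothetical path
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"Test accuracy: {accuracy:.4f}")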

Key Performance Metrics

| Metric   | Value  | Description                                              |
|----------|--------|----------------------------------------------------------|
| Accuracy | 98.18% | Overall correctness on the test set.                     |
| APCER    | 3.57%  | Rate of 'Drowsy' drivers missed (false negatives).       |
| BPCER    | 0.00%  | Rate of 'Non Drowsy' drivers flagged (false positives).  |
| ACER     | 1.78%  | Average of APCER and BPCER.                              |

APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver.

Confusion Matrix
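
Taking Drowsy as the positive class (index 0, as in the usage example), the reported error rates can be derived from the confusion matrix as sketched below, reusing the y_true and y_pred lists from the evaluation loop above.

# Derive APCER, BPCER, and ACER from the binary confusion matrix
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true, y_pred, labels=[0, 1])  # rows = true class, columns = predicted class
drowsy_missed = cm[0, 1]        # true Drowsy predicted as Non Drowsy (false negatives)
non_drowsy_flagged = cm[1, 0]   # true Non Drowsy predicted as Drowsy (false positives)

apcer = drowsy_missed / cm[0].sum()        # share of Drowsy samples the model missed
bpcer = non_drowsy_flagged / cm[1].sum()   # share of Non Drowsy samples wrongly flagged
acer = (apcer + bpcer) / 2

print(f"APCER: {apcer:.2%}  BPCER: {bpcer:.2%}  ACER: {acer:.2%}")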

Model Explainability (Grad-CAM)

To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.

Grad-CAM Visualization
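
Heatmaps like these can be reproduced with the third-party pytorch-grad-cam package. The snippet below is one possible sketch, reusing image and input_tensor from the usage example; the choice of model.stages[-1] as the target layer is an assumption about the timm MobileViT v2 module layout.

# Grad-CAM heatmap for a single image (requires: pip install grad-cam)
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

target_layers = [model.stages[-1]]  # assumed target: last stage of the timm backbone
cam = GradCAM(model=model, target_layers=target_layers)

rgb_img = np.array(image.resize((224, 224)), dtype=np.float32) / 255.0       # float image in [0, 1]
grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(0)])[0]  # 0 = Drowsy

overlay = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)  # uint8 heatmap overlay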

Intended Use and Limitations

This model is intended as a proof-of-concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing.

Real-world performance may vary based on:

  • Lighting conditions (especially at night).
  • Camera angles and distance.
  • Occlusions (e.g., sunglasses, hats, hands on face).
  • Individual differences not represented in the training data.

This model card is based on the training notebook MobileViT_Drowsiness.ipynb.
