---
license: mit
library_name: timm
tags:
- image-classification
- mobilevit
- timm
- drowsiness-detection
- computer-vision
- pytorch
widget:
- modelId: mosesb/drowsiness-detection-mobileViT-v2
  title: Drowsiness Detection with MobileViT v2
  url: >-
    https://huggingface.co/spaces/mosesb/drowsiness-detection-mobileViT-v2/resolve/main/output_grad_cam.jpg
datasets:
- ismailnasri20/driver-drowsiness-dataset-ddd
- yasharjebraeily/drowsy-detection-dataset
metrics:
- accuracy
- f1
- precision
- recall
base_model:
- apple/mobilevitv2-1.0-imagenet1k-256
---

# MobileViT v2 for Drowsiness Detection

This repository contains a `MobileViT v2` classification model fine-tuned to detect driver drowsiness from images. MobileViT v2 is a state-of-the-art, lightweight hybrid architecture that combines convolutions with Vision Transformers, making it both efficient and accurate. The model classifies input images into two categories: `Drowsy` and `Non Drowsy`.

The model was trained in PyTorch using the `timm` library and demonstrates high performance on an unseen test set, making it a reliable foundation for driver safety applications.

## Model Details

* **Architecture:** `mobilevitv2_200`
* **Fine-tuned on:** a combined dataset for driver drowsiness detection
* **Classes:** `Drowsy`, `Non Drowsy`
* **Frameworks:** PyTorch, timm

## How to Get Started

You can use this model with the `timm` and `torch` libraries. First, make sure the required packages are installed (`pip install timm torch torchvision`) and that you have the `best_model.pt` file from this repository.

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# --- 1. Set up the model and preprocessing ---

# Use the same transformations as during validation/testing
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Class names (order must match training: Drowsy=0, Non Drowsy=1)
class_names = ['Drowsy', 'Non Drowsy']

# Build the architecture and load the fine-tuned weights
model = timm.create_model('mobilevitv2_200', pretrained=False, num_classes=2)
model_path = 'best_model.pt'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

# --- 2. Run inference ---
image_path = 'path/to/your/image.jpg'
image = Image.open(image_path).convert('RGB')

# Preprocess the image and add a batch dimension
input_tensor = val_test_transform(image).unsqueeze(0)

# Get the model prediction
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

top_prob, top_class_index = torch.topk(probabilities, 1)
class_name = class_names[top_class_index.item()]
confidence = top_prob.item()

print(f"Prediction: {class_name} with confidence {confidence:.4f}")
```

## Training Procedure

The model was fine-tuned on a dataset of over 40,000 driver images. The training process involved:

- **Data Augmentation:** a strong augmentation pipeline was used for training, including `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `RandomErasing` (see the sketch after this list).
- **Transfer Learning:** the model was initialized with weights pretrained on ImageNet, enabling robust feature extraction and fast convergence.
- **Early Stopping:** training was halted after 30 epochs without improvement in validation accuracy to prevent overfitting.
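The card does not record the exact augmentation settings, so the following is only a minimal sketch of how such a training pipeline could be assembled with `torchvision.transforms`. The crop scale, jitter strengths, and erasing probability are illustrative assumptions, not the values used for this checkpoint.

```python
# Sketch of a training augmentation pipeline matching the transforms named above.
# Parameter values are assumptions for illustration, not the original settings.
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),    # assumed crop scale
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.1),         # assumed strengths
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),                        # assumed probability; applied after ToTensor
])
```

At evaluation time, the resize-and-normalize pipeline shown in the inference example above would be used instead of these augmentations.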
### Key Hyperparameters

- **Image Size:** 224x224
- **Batch Size:** 64
- **Optimizer:** AdamW (lr=1e-4)
- **Scheduler:** ExponentialLR (gamma=0.90)
- **Loss Function:** CrossEntropyLoss

![Training Results](training_plot.png)

## Evaluation

The model was evaluated on a completely **unseen test set** (drawn from a different dataset than the primary training data) to ensure a fair assessment of its generalization capabilities.

### Key Performance Metrics

| Metric | Value | Description |
| :----: | :----: | :------------------------------------------------- |
| **Accuracy** | 98.18% | Overall correctness on the test set. |
| **APCER** | 3.57% | Rate of 'Drowsy' drivers missed (False Negatives). |
| **BPCER** | 0.00% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
| **ACER** | 1.78% | Average of APCER and BPCER. |

*APCER (Attack Presentation Classification Error Rate, adapted here) is the most critical safety metric, as it measures the failure to detect a drowsy driver. A sketch of how these rates can be derived from a confusion matrix appears at the end of this card.*

![Confusion Matrix](output_confusion_matrix.png)

### Model Explainability (Grad-CAM)

To ensure the model focuses on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the driver's eyes, mouth, and head position, which are key indicators of drowsiness.

![Grad-CAM Visualization](output_grad_cam.jpg)

## Intended Use and Limitations

This model is intended as a proof of concept for driver safety systems and academic research. It should not be used as the sole mechanism for preventing accidents in a production environment without further rigorous testing. Real-world performance may vary based on:

- Lighting conditions (especially at night).
- Camera angles and distance.
- Occlusions (e.g., sunglasses, hats, hands on the face).
- Individual differences not represented in the training data.

*This model card is based on the training notebook [`MobileViT_Drowsiness.ipynb`](https://github.com/mosesab/MobileViT-Drowsiness-Detection/blob/main/MobileViT_Drowsiness.ipynb).*
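### Appendix: Metric Computation Sketch

For reference, this is a minimal sketch of how APCER, BPCER, and ACER map onto a binary confusion matrix in this adapted setting, with `Drowsy` treated as the positive class (index 0, matching the inference example above). The function name and the example counts below are hypothetical and only illustrate the arithmetic; they are not taken from the training notebook or the test-set results in the table.

```python
# Sketch (not from the training notebook) of the adapted APCER/BPCER/ACER arithmetic.
def drowsiness_error_rates(tp: int, fn: int, tn: int, fp: int) -> dict:
    """Compute error rates for the two-class drowsiness setup.

    tp: Drowsy correctly flagged      fn: Drowsy missed
    tn: Non Drowsy correctly passed   fp: Non Drowsy wrongly flagged
    """
    apcer = fn / (tp + fn) if (tp + fn) else 0.0   # drowsy drivers missed (false-negative rate)
    bpcer = fp / (tn + fp) if (tn + fp) else 0.0   # alert drivers wrongly flagged (false-positive rate)
    acer = (apcer + bpcer) / 2                     # average of the two error rates
    accuracy = (tp + tn) / (tp + fn + tn + fp)     # overall correctness
    return {"accuracy": accuracy, "apcer": apcer, "bpcer": bpcer, "acer": acer}

# Hypothetical counts, purely to illustrate the arithmetic:
print(drowsiness_error_rates(tp=540, fn=20, tn=560, fp=0))
```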