Whisper Audio Classification Model

A fine-tuned Whisper model for multi-task audio classification, specifically trained to classify English accents (23 classes) and speaker gender (2 classes) from speech audio.

🎯 Model Overview

This model uses OpenAI's Whisper encoder as a feature extractor with custom classification heads for:

  • Accent Classification: Identifies 23 different English accents
  • Gender Classification: Classifies speaker as male or female

Model Architecture

  • Base Model: openai/whisper-small.en
  • Encoder: Frozen Whisper encoder (for feature extraction)
  • Classification Heads: Two MLP heads (accent and gender) with dropout regularization
  • Multi-task Learning: Jointly trained on both accent and gender classification (a sketch of the joint objective follows below)
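
Only the forward pass shown under Quick Start is needed for inference. For reference, here is a minimal sketch of how such a joint objective could be set up, assuming a weighted sum of per-task cross-entropy losses with both weights defaulting to 1 (the exact formulation used during training is not documented here):

import torch.nn as nn

accent_loss_fn = nn.CrossEntropyLoss()
gender_loss_fn = nn.CrossEntropyLoss()

def multitask_loss(outputs, accent_labels, gender_labels,
                   accent_weight=1.0, gender_weight=1.0):
    # Weighted sum of the per-task cross-entropy terms (illustrative default weights)
    accent_loss = accent_loss_fn(outputs["accent_logits"], accent_labels)
    gender_loss = gender_loss_fn(outputs["gender_logits"], gender_labels)
    return accent_weight * accent_loss + gender_weight * gender_loss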

🚀 Quick Start

Prerequisites

pip install torch transformers datasets numpy scikit-learn librosa safetensors

Basic Usage

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor, WhisperModel
import numpy as np

# Define the model class (same as training)
class WhisperClassifier(nn.Module):
    def __init__(self, model_name="openai/whisper-small.en", num_accent_classes=23, num_gender_classes=2, 
                 freeze_encoder=True, dropout_rate=0.3):
        super().__init__()
        
        self.whisper = WhisperModel.from_pretrained(model_name)
        
        if freeze_encoder:
            for param in self.whisper.encoder.parameters():
                param.requires_grad = False
                
        self.hidden_size = self.whisper.config.d_model
        self.dropout = nn.Dropout(dropout_rate)
        
        # Accent classification head
        self.accent_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_accent_classes)
        )
        
        # Gender classification head
        self.gender_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_gender_classes)
        )
        
        self.num_accent_classes = num_accent_classes
        self.num_gender_classes = num_gender_classes
        
    def forward(self, input_features, accent_labels=None, gender_labels=None):
        # Encode the log-mel input features with the (optionally frozen) Whisper encoder
        encoder_outputs = self.whisper.encoder(input_features)
        hidden_states = encoder_outputs.last_hidden_state
        # Mean-pool over time to get a single embedding per utterance
        pooled_output = hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        
        accent_logits = self.accent_classifier(pooled_output)
        gender_logits = self.gender_classifier(pooled_output)
        
        return {
            'accent_logits': accent_logits,
            'gender_logits': gender_logits,
        }

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperClassifier()

# Load the trained weights (.safetensors checkpoints require the safetensors library, not torch.load)
from safetensors.torch import load_file
model.load_state_dict(load_file("./model_step1000.safetensors"))
model.to(device)
model.eval()

# Initialize feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en")

Making Predictions

def predict_audio(audio_file_path, model, feature_extractor, device):
    """
    Predict accent and gender from an audio file
    
    Args:
        audio_file_path: Path to audio file (.wav, .mp3, etc.)
        model: Trained WhisperClassifier model
        feature_extractor: Whisper feature extractor
        device: torch device (cuda/cpu)
    
    Returns:
        Dictionary with predictions and confidence scores
    """
    import librosa
    
    # Load audio file
    audio, sr = librosa.load(audio_file_path, sr=16000, mono=True)
    
    # Extract features
    inputs = feature_extractor(
        audio, 
        sampling_rate=sr, 
        return_tensors="pt"
    )
    
    # Move to device
    input_features = inputs.input_features.to(device)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(input_features=input_features)
        
        # Get probabilities
        accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
        gender_probs = F.softmax(outputs["gender_logits"], dim=-1)
        
        # Get predictions
        accent_pred = torch.argmax(accent_probs, dim=-1).item()
        gender_pred = torch.argmax(gender_probs, dim=-1).item()
        
        # Get confidence scores
        accent_confidence = accent_probs[0, accent_pred].item()
        gender_confidence = gender_probs[0, gender_pred].item()
    
    # Map predictions to labels
    accent_names = [
        'african', 'australia', 'bermuda', 'canada', 'england', 'hongkong', 
        'indian', 'ireland', 'malaysia', 'newzealand', 'philippines', 
        'scotland', 'singapore', 'southafrica', 'us', 'wales'
        # Add all 23 accent names based on your dataset
    ]
    
    accent_name = accent_names[accent_pred] if accent_pred < len(accent_names) else f"accent_{accent_pred}"
    gender_name = "male" if gender_pred == 0 else "female"
    
    return {
        'accent': accent_name,
        'accent_confidence': accent_confidence,
        'gender': gender_name,
        'gender_confidence': gender_confidence
    }

# Example usage
result = predict_audio("path/to/your/audio.wav", model, feature_extractor, device)
print(f"Predicted Accent: {result['accent']} (confidence: {result['accent_confidence']:.3f})")
print(f"Predicted Gender: {result['gender']} (confidence: {result['gender_confidence']:.3f})")

Batch Predictions

def predict_batch(audio_files, model, feature_extractor, device, batch_size=8):
    """
    Predict accent and gender for multiple audio files
    """
    import librosa
    from torch.utils.data import DataLoader, Dataset
    
    class AudioDataset(Dataset):
        def __init__(self, audio_files):
            self.audio_files = audio_files
            
        def __len__(self):
            return len(self.audio_files)
            
        def __getitem__(self, idx):
            audio, sr = librosa.load(self.audio_files[idx], sr=16000, mono=True)
            inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
            return inputs.input_features.squeeze(0)
    
    dataset = AudioDataset(audio_files)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    results = []
    model.eval()
    
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            outputs = model(input_features=batch)
            
            accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
            gender_probs = F.softmax(outputs["gender_logits"], dim=-1)
            
            accent_preds = torch.argmax(accent_probs, dim=-1)
            gender_preds = torch.argmax(gender_probs, dim=-1)
            
            for i in range(len(batch)):
                results.append({
                    'accent_id': accent_preds[i].item(),
                    'accent_confidence': accent_probs[i, accent_preds[i]].item(),
                    'gender_id': gender_preds[i].item(),
                    'gender_confidence': gender_probs[i, gender_preds[i]].item(),
                })
    
    return results
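
A short usage sketch (the file paths are placeholders); the returned ids can be mapped to names with the same accent_names list used in predict_audio:

audio_files = ["clip_01.wav", "clip_02.wav", "clip_03.wav"]  # replace with real paths
batch_results = predict_batch(audio_files, model, feature_extractor, device, batch_size=2)
for path, res in zip(audio_files, batch_results):
    print(f"{path}: accent_id={res['accent_id']} ({res['accent_confidence']:.3f}), "
          f"gender_id={res['gender_id']} ({res['gender_confidence']:.3f})")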

📊 Model Performance

The model was trained on the English Accent Dataset with the following performance:

  • Accent Classification: Achieves high accuracy across the 23 English accent varieties (see the evaluation sketch below)
  • Gender Classification: Robust binary classification for male/female voices
  • Multi-task Learning: Benefits from joint training on both tasks
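
scikit-learn (listed under Requirements) can be used to score the model's predictions against ground-truth labels. A minimal sketch, assuming hypothetical lists of true and predicted accent class ids:

from sklearn.metrics import accuracy_score, classification_report

# Hypothetical ids; in practice these come from predict_batch and your dataset's labels
true_accents = [4, 14, 6, 11]
pred_accents = [4, 14, 6, 5]

print("Accent accuracy:", accuracy_score(true_accents, pred_accents))
print(classification_report(true_accents, pred_accents, zero_division=0))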

Supported Accent Classes

The model can classify the following accent varieties:

  1. African
  2. Australian
  3. Bermuda
  4. Canadian
  5. England
  6. Hong Kong
  7. Indian
  8. Irish
  9. Malaysian
  10. New Zealand
  11. Philippines
  12. Scottish
  13. Singapore
  14. South African
  15. US American
  16. Welsh
  …and additional varieties, totaling 23 classes.

🔧 Advanced Usage

Custom Audio Processing

def preprocess_custom_audio(audio_array, sample_rate, target_sr=16000):
    """
    Preprocess custom audio data
    """
    import librosa
    
    # Resample if needed
    if sample_rate != target_sr:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sr)
    
    # Ensure mono
    if len(audio_array.shape) > 1:
        audio_array = librosa.to_mono(audio_array)
    
    # Peak-normalize (guard against silent/all-zero input)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array = audio_array / peak
    
    return audio_array
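
A usage sketch, running a hypothetical stereo 44.1 kHz recording through the helper and then into the feature extractor (the input array is a placeholder for real audio data):

# Hypothetical input: 3 seconds of stereo audio at 44.1 kHz as a NumPy array
raw_audio = np.random.randn(2, 44100 * 3).astype(np.float32)  # replace with real audio data
clean_audio = preprocess_custom_audio(raw_audio, sample_rate=44100)
inputs = feature_extractor(clean_audio, sampling_rate=16000, return_tensors="pt")
input_features = inputs.input_features.to(device)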

Getting Top-K Predictions

def get_top_k_predictions(audio_file, model, feature_extractor, device, k=3):
    """
    Get top-k accent predictions with confidence scores
    """
    import librosa

    # Load and preprocess audio (same steps as in predict_audio above)
    audio, sr = librosa.load(audio_file, sr=16000, mono=True)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        outputs = model(input_features=input_features)
        accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
        
        # Get top-k predictions
        top_k_probs, top_k_indices = torch.topk(accent_probs, k, dim=-1)
        
        results = []
        for i in range(k):
            results.append({
                'accent_id': top_k_indices[0, i].item(),
                'confidence': top_k_probs[0, i].item()
            })
    
    return results
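
For example, printing the three most likely accents for a clip (ids can be mapped to names with the accent_names list from predict_audio):

top3 = get_top_k_predictions("path/to/your/audio.wav", model, feature_extractor, device, k=3)
for rank, entry in enumerate(top3, start=1):
    print(f"{rank}. accent_id={entry['accent_id']} (confidence: {entry['confidence']:.3f})")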

📋 Requirements

  • Python 3.8+
  • PyTorch 1.9+
  • Transformers 4.20+
  • librosa (for audio loading)
  • safetensors (for loading the trained weights)
  • numpy
  • scikit-learn (for evaluation metrics)

📄 License

This model is based on OpenAI's Whisper and follows the same licensing terms. Please check the original Whisper repository for license details.

πŸ™ Acknowledgments

  • OpenAI for the Whisper model
  • The English Accent Dataset creators
  • Hugging Face Transformers library

Note: This model is trained for research and educational purposes. Performance may vary on different audio qualities, recording conditions, and accent varieties not represented in the training data.
