---
license: mit
datasets:
- westbrook/English_Accent_DataSet
base_model:
- openai/whisper-small
pipeline_tag: audio-classification
tags:
- accent
- gender
---
# Whisper Audio Classification Model
A fine-tuned Whisper model for multi-task audio classification, specifically trained to classify English accents (23 classes) and speaker gender (2 classes) from speech audio.
## Model Overview
This model uses OpenAI's Whisper encoder as a feature extractor with custom classification heads for:
- Accent Classification: Identifies 23 different English accents
- Gender Classification: Classifies speaker as male or female
### Model Architecture
- Base Model: `openai/whisper-small.en`
- Encoder: Frozen Whisper encoder (for feature extraction)
- Classification Heads: Custom neural networks with dropout for robust predictions
- Multi-task Learning: Jointly trained on both accent and gender classification
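This card does not include the training loop itself; a minimal sketch of how the joint objective could be formed (assuming an unweighted sum of the two cross-entropy terms, with `model`, `input_features`, and the label tensors already prepared as in the Quick Start below) looks like:

```python
import torch.nn.functional as F

# Hypothetical joint objective: equal-weight sum of the two task losses.
# `input_features`, `accent_labels`, and `gender_labels` are placeholder batch tensors.
outputs = model(input_features=input_features)
accent_loss = F.cross_entropy(outputs["accent_logits"], accent_labels)
gender_loss = F.cross_entropy(outputs["gender_logits"], gender_labels)
loss = accent_loss + gender_loss  # the task weighting here is an assumption, not the exact training recipe
loss.backward()
```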
## Quick Start
### Prerequisites
```bash
pip install torch transformers datasets numpy scikit-learn librosa safetensors
```
### Basic Usage
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor, WhisperModel
from safetensors.torch import load_file
import numpy as np

# Define the model class (same as training)
class WhisperClassifier(nn.Module):
    def __init__(self, model_name="openai/whisper-small.en", num_accent_classes=23, num_gender_classes=2,
                 freeze_encoder=True, dropout_rate=0.3):
        super().__init__()
        self.whisper = WhisperModel.from_pretrained(model_name)
        if freeze_encoder:
            for param in self.whisper.encoder.parameters():
                param.requires_grad = False
        self.hidden_size = self.whisper.config.d_model
        self.dropout = nn.Dropout(dropout_rate)

        # Accent classification head
        self.accent_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_accent_classes)
        )

        # Gender classification head
        self.gender_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_gender_classes)
        )

        self.num_accent_classes = num_accent_classes
        self.num_gender_classes = num_gender_classes

    def forward(self, input_features, accent_labels=None, gender_labels=None):
        # Encode the log-Mel features and mean-pool over time
        encoder_outputs = self.whisper.encoder(input_features)
        hidden_states = encoder_outputs.last_hidden_state
        pooled_output = hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)

        accent_logits = self.accent_classifier(pooled_output)
        gender_logits = self.gender_classifier(pooled_output)

        return {
            'accent_logits': accent_logits,
            'gender_logits': gender_logits,
        }

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperClassifier()

# Load the trained weights (.safetensors files are loaded with safetensors, not torch.load)
state_dict = load_file("./model_step1000.safetensors")
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Initialize feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en")
```
### Making Predictions
```python
def predict_audio(audio_file_path, model, feature_extractor, device):
    """
    Predict accent and gender from an audio file

    Args:
        audio_file_path: Path to audio file (.wav, .mp3, etc.)
        model: Trained WhisperClassifier model
        feature_extractor: Whisper feature extractor
        device: torch device (cuda/cpu)

    Returns:
        Dictionary with predictions and confidence scores
    """
    import librosa

    # Load audio file
    audio, sr = librosa.load(audio_file_path, sr=16000, mono=True)

    # Extract features
    inputs = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt"
    )

    # Move to device
    input_features = inputs.input_features.to(device)

    # Get predictions
    with torch.no_grad():
        outputs = model(input_features=input_features)

    # Get probabilities
    accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
    gender_probs = F.softmax(outputs["gender_logits"], dim=-1)

    # Get predictions
    accent_pred = torch.argmax(accent_probs, dim=-1).item()
    gender_pred = torch.argmax(gender_probs, dim=-1).item()

    # Get confidence scores
    accent_confidence = accent_probs[0, accent_pred].item()
    gender_confidence = gender_probs[0, gender_pred].item()

    # Map predictions to labels
    accent_names = [
        'african', 'australia', 'bermuda', 'canada', 'england', 'hongkong',
        'indian', 'ireland', 'malaysia', 'newzealand', 'philippines',
        'scotland', 'singapore', 'southafrica', 'us', 'wales'
        # Add all 23 accent names based on your dataset
    ]
    accent_name = accent_names[accent_pred] if accent_pred < len(accent_names) else f"accent_{accent_pred}"
    gender_name = "male" if gender_pred == 0 else "female"

    return {
        'accent': accent_name,
        'accent_confidence': accent_confidence,
        'gender': gender_name,
        'gender_confidence': gender_confidence
    }

# Example usage
result = predict_audio("path/to/your/audio.wav", model, feature_extractor, device)
print(f"Predicted Accent: {result['accent']} (confidence: {result['accent_confidence']:.3f})")
print(f"Predicted Gender: {result['gender']} (confidence: {result['gender_confidence']:.3f})")
```
### Batch Predictions
```python
def predict_batch(audio_files, model, feature_extractor, device, batch_size=8):
    """
    Predict accent and gender for multiple audio files
    """
    import librosa
    from torch.utils.data import DataLoader, Dataset

    class AudioDataset(Dataset):
        def __init__(self, audio_files):
            self.audio_files = audio_files

        def __len__(self):
            return len(self.audio_files)

        def __getitem__(self, idx):
            audio, sr = librosa.load(self.audio_files[idx], sr=16000, mono=True)
            inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
            return inputs.input_features.squeeze(0)

    dataset = AudioDataset(audio_files)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    results = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            outputs = model(input_features=batch)

            accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
            gender_probs = F.softmax(outputs["gender_logits"], dim=-1)

            accent_preds = torch.argmax(accent_probs, dim=-1)
            gender_preds = torch.argmax(gender_probs, dim=-1)

            for i in range(len(batch)):
                results.append({
                    'accent_id': accent_preds[i].item(),
                    'accent_confidence': accent_probs[i, accent_preds[i]].item(),
                    'gender_id': gender_preds[i].item(),
                    'gender_confidence': gender_probs[i, gender_preds[i]].item(),
                })
    return results
```
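A minimal usage sketch for `predict_batch` (the file paths below are placeholders):

```python
# Placeholder paths; substitute your own audio files
audio_files = ["clip_01.wav", "clip_02.wav", "clip_03.wav"]
batch_results = predict_batch(audio_files, model, feature_extractor, device, batch_size=2)
for path, res in zip(audio_files, batch_results):
    print(f"{path}: accent_id={res['accent_id']} ({res['accent_confidence']:.3f}), "
          f"gender_id={res['gender_id']} ({res['gender_confidence']:.3f})")
```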
## Model Performance
The model was trained on the English Accent Dataset with the following performance:
- Accent Classification: Achieves high accuracy across 23 English accent varieties
- Gender Classification: Robust binary classification for male/female voices
- Multi-task Learning: Benefits from joint training on both tasks
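To verify these claims on your own labeled split, a minimal evaluation sketch with scikit-learn (the file paths and label ids below are placeholders, and assume the same label-to-id mapping used during training) could look like:

```python
from sklearn.metrics import accuracy_score

# Placeholder evaluation split: audio paths plus integer label ids
eval_files = ["eval_01.wav", "eval_02.wav"]
accent_labels = [4, 14]      # hypothetical accent ids
gender_labels = [0, 1]       # hypothetical gender ids (0 = male, 1 = female)

preds = predict_batch(eval_files, model, feature_extractor, device)
accent_acc = accuracy_score(accent_labels, [p["accent_id"] for p in preds])
gender_acc = accuracy_score(gender_labels, [p["gender_id"] for p in preds])
print(f"Accent accuracy: {accent_acc:.3f}, Gender accuracy: {gender_acc:.3f}")
```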
### Supported Accent Classes
The model can classify the following accent varieties:
- African
- Australian
- Bermuda
- Canadian
- England
- Hong Kong
- Indian
- Irish
- Malaysian
- New Zealand
- Philippines
- Scottish
- Singapore
- South African
- US American
- Welsh ... (and more, totaling 23 classes)
## Advanced Usage
### Custom Audio Processing
```python
def preprocess_custom_audio(audio_array, sample_rate, target_sr=16000):
    """
    Preprocess custom audio data
    """
    import librosa

    # Resample if needed
    if sample_rate != target_sr:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sr)

    # Ensure mono
    if len(audio_array.shape) > 1:
        audio_array = librosa.to_mono(audio_array)

    # Peak-normalize (guard against silent clips to avoid division by zero)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array = audio_array / peak

    return audio_array
```
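A short sketch of feeding the preprocessed array into the model; here `raw_audio` is a hypothetical in-memory NumPy array recorded at 44.1 kHz:

```python
# `raw_audio` is a placeholder in-memory recording (NumPy array at 44.1 kHz)
audio = preprocess_custom_audio(raw_audio, sample_rate=44100)
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    outputs = model(input_features=inputs.input_features.to(device))
accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
```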
### Getting Top-K Predictions
```python
def get_top_k_predictions(audio_file, model, feature_extractor, device, k=3):
    """
    Get top-k accent predictions with confidence scores
    """
    import librosa

    # Load and preprocess audio as in predict_audio above
    audio, sr = librosa.load(audio_file, sr=16000, mono=True)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        outputs = model(input_features=input_features)
        accent_probs = F.softmax(outputs["accent_logits"], dim=-1)

    # Get top-k predictions
    top_k_probs, top_k_indices = torch.topk(accent_probs, k, dim=-1)

    results = []
    for i in range(k):
        results.append({
            'accent_id': top_k_indices[0, i].item(),
            'confidence': top_k_probs[0, i].item()
        })
    return results
```
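Usage sketch (the audio path is a placeholder):

```python
top3 = get_top_k_predictions("path/to/your/audio.wav", model, feature_extractor, device, k=3)
for rank, pred in enumerate(top3, start=1):
    print(f"{rank}. accent_id={pred['accent_id']} (confidence: {pred['confidence']:.3f})")
```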
## Requirements
- Python 3.8+
- PyTorch 1.9+
- Transformers 4.20+
- librosa (for audio loading)
- safetensors (for loading the trained weights)
- numpy
- scikit-learn (for evaluation metrics)
## License
This model is based on OpenAI's Whisper and follows the same licensing terms. Please check the original Whisper repository for license details.
## Acknowledgments
- OpenAI for the Whisper model
- The English Accent Dataset creators
- Hugging Face Transformers library
Note: This model is trained for research and educational purposes. Performance may vary on different audio qualities, recording conditions, and accent varieties not represented in the training data.