Whisper Audio Classification Model
A fine-tuned Whisper model for multi-task audio classification, specifically trained to classify English accents (23 classes) and speaker gender (2 classes) from speech audio.
Model Overview
This model uses OpenAI's Whisper encoder as a feature extractor with custom classification heads for:
- Accent Classification: Identifies 23 different English accents
- Gender Classification: Classifies speaker as male or female
Model Architecture
- Base Model: openai/whisper-small.en
- Encoder: Frozen Whisper encoder (for feature extraction)
- Classification Heads: Custom neural networks with dropout for robust predictions
- Multi-task Learning: Jointly trained on both accent and gender classification
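To make the multi-task setup concrete, here is a minimal sketch of how the joint objective could be computed. The exact loss weighting used during training is not documented here, so an unweighted sum of the two cross-entropy terms is assumed:

# Minimal sketch of a joint multi-task objective (assumed unweighted sum of
# the two cross-entropy losses; the actual training weighting is not documented here)
import torch.nn.functional as F

def multi_task_loss(outputs, accent_labels, gender_labels):
    accent_loss = F.cross_entropy(outputs['accent_logits'], accent_labels)
    gender_loss = F.cross_entropy(outputs['gender_logits'], gender_labels)
    return accent_loss + gender_loss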
Quick Start
Prerequisites
pip install torch transformers datasets numpy scikit-learn librosa safetensors
Basic Usage
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor, WhisperModel
import numpy as np
# Define the model class (same as training)
class WhisperClassifier(nn.Module):
    def __init__(self, model_name="openai/whisper-small.en", num_accent_classes=23, num_gender_classes=2,
                 freeze_encoder=True, dropout_rate=0.3):
        super().__init__()
        self.whisper = WhisperModel.from_pretrained(model_name)
        if freeze_encoder:
            for param in self.whisper.encoder.parameters():
                param.requires_grad = False
        self.hidden_size = self.whisper.config.d_model
        self.dropout = nn.Dropout(dropout_rate)

        # Accent classification head
        self.accent_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_accent_classes)
        )

        # Gender classification head
        self.gender_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_gender_classes)
        )

        self.num_accent_classes = num_accent_classes
        self.num_gender_classes = num_gender_classes

    def forward(self, input_features, accent_labels=None, gender_labels=None):
        encoder_outputs = self.whisper.encoder(input_features)
        hidden_states = encoder_outputs.last_hidden_state
        pooled_output = hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        accent_logits = self.accent_classifier(pooled_output)
        gender_logits = self.gender_classifier(pooled_output)
        return {
            'accent_logits': accent_logits,
            'gender_logits': gender_logits,
        }
# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperClassifier()
# Load the trained weights (the checkpoint is in safetensors format,
# so load it with safetensors rather than torch.load)
from safetensors.torch import load_file
model.load_state_dict(load_file("./model_step1000.safetensors", device="cpu"))
model.to(device)
model.eval()
# Initialize feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en")
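Optionally, you can check the parameter split after loading. Note that the class above freezes only the encoder, so the (unused) Whisper decoder still appears in the trainable count alongside the classification heads:

# Optional check: how many parameters are frozen vs. trainable after loading.
# Only the encoder is frozen by the class above, so the unused decoder still
# counts as trainable alongside the classification heads.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable parameters: {trainable:,} | Frozen parameters: {frozen:,}")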
Making Predictions
def predict_audio(audio_file_path, model, feature_extractor, device):
    """
    Predict accent and gender from an audio file

    Args:
        audio_file_path: Path to audio file (.wav, .mp3, etc.)
        model: Trained WhisperClassifier model
        feature_extractor: Whisper feature extractor
        device: torch device (cuda/cpu)

    Returns:
        Dictionary with predictions and confidence scores
    """
    import librosa

    # Load audio file
    audio, sr = librosa.load(audio_file_path, sr=16000, mono=True)

    # Extract features
    inputs = feature_extractor(
        audio,
        sampling_rate=sr,
        return_tensors="pt"
    )

    # Move to device
    input_features = inputs.input_features.to(device)

    # Get predictions
    with torch.no_grad():
        outputs = model(input_features=input_features)

        # Get probabilities
        accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
        gender_probs = F.softmax(outputs["gender_logits"], dim=-1)

        # Get predictions
        accent_pred = torch.argmax(accent_probs, dim=-1).item()
        gender_pred = torch.argmax(gender_probs, dim=-1).item()

        # Get confidence scores
        accent_confidence = accent_probs[0, accent_pred].item()
        gender_confidence = gender_probs[0, gender_pred].item()

    # Map predictions to labels
    accent_names = [
        'african', 'australia', 'bermuda', 'canada', 'england', 'hongkong',
        'indian', 'ireland', 'malaysia', 'newzealand', 'philippines',
        'scotland', 'singapore', 'southafrica', 'us', 'wales'
        # Add all 23 accent names based on your dataset
    ]

    accent_name = accent_names[accent_pred] if accent_pred < len(accent_names) else f"accent_{accent_pred}"
    gender_name = "male" if gender_pred == 0 else "female"

    return {
        'accent': accent_name,
        'accent_confidence': accent_confidence,
        'gender': gender_name,
        'gender_confidence': gender_confidence
    }
# Example usage
result = predict_audio("path/to/your/audio.wav", model, feature_extractor, device)
print(f"Predicted Accent: {result['accent']} (confidence: {result['accent_confidence']:.3f})")
print(f"Predicted Gender: {result['gender']} (confidence: {result['gender_confidence']:.3f})")
Batch Predictions
def predict_batch(audio_files, model, feature_extractor, device, batch_size=8):
    """
    Predict accent and gender for multiple audio files
    """
    import librosa
    from torch.utils.data import DataLoader, Dataset

    class AudioDataset(Dataset):
        def __init__(self, audio_files):
            self.audio_files = audio_files

        def __len__(self):
            return len(self.audio_files)

        def __getitem__(self, idx):
            audio, sr = librosa.load(self.audio_files[idx], sr=16000, mono=True)
            inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
            return inputs.input_features.squeeze(0)

    dataset = AudioDataset(audio_files)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    results = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            outputs = model(input_features=batch)

            accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
            gender_probs = F.softmax(outputs["gender_logits"], dim=-1)

            accent_preds = torch.argmax(accent_probs, dim=-1)
            gender_preds = torch.argmax(gender_probs, dim=-1)

            for i in range(len(batch)):
                results.append({
                    'accent_id': accent_preds[i].item(),
                    'accent_confidence': accent_probs[i, accent_preds[i]].item(),
                    'gender_id': gender_preds[i].item(),
                    'gender_confidence': gender_probs[i, gender_preds[i]].item(),
                })

    return results
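Example usage; the file paths below are placeholders, and the gender id convention mirrors predict_audio above:

# Example: batch prediction over a few clips (paths are placeholders)
audio_files = ["clip_01.wav", "clip_02.wav", "clip_03.wav"]
batch_results = predict_batch(audio_files, model, feature_extractor, device, batch_size=2)

for path, res in zip(audio_files, batch_results):
    gender = "male" if res['gender_id'] == 0 else "female"  # same convention as predict_audio
    print(f"{path}: accent_id={res['accent_id']} ({res['accent_confidence']:.3f}), "
          f"gender={gender} ({res['gender_confidence']:.3f})")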
Model Performance
The model was trained on the English Accent Dataset with the following characteristics:
- Accent Classification: Achieves high accuracy across 23 English accent varieties
- Gender Classification: Robust binary classification for male/female voices
- Multi-task Learning: Benefits from joint training on both tasks
Supported Accent Classes
The model can classify the following accent varieties:
- African
- Australian
- Bermuda
- Canadian
- England
- Hong Kong
- Indian
- Irish
- Malaysian
- New Zealand
- Philippines
- Scottish
- Singapore
- South African
- US American
- Welsh
...and additional varieties, totaling 23 classes.
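If your copy of the training dataset exposes the accent label as a ClassLabel feature, the full 23-name list can be recovered programmatically. The dataset identifier and column name below are placeholders, not the actual values used for this model:

# Hypothetical sketch: recover the full accent label list from the training
# dataset, assuming the accent column is a ClassLabel feature.
# The dataset identifier and column name are placeholders.
from datasets import load_dataset

ds = load_dataset("your-username/english-accent-dataset", split="train")  # placeholder id
accent_names = ds.features["accent"].names  # assumes 'accent' is a ClassLabel column
print(len(accent_names), accent_names)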
Advanced Usage
Custom Audio Processing
def preprocess_custom_audio(audio_array, sample_rate, target_sr=16000):
    """
    Preprocess custom audio data
    """
    import librosa

    # Resample if needed
    if sample_rate != target_sr:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sr)

    # Ensure mono
    if len(audio_array.shape) > 1:
        audio_array = librosa.to_mono(audio_array)

    # Peak-normalize (skip silent clips to avoid division by zero)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array = audio_array / peak

    return audio_array
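A short sketch of how preprocessed audio can be fed to the model; the synthetic stereo 44.1 kHz array below stands in for your own recording:

# Example: classify an in-memory recording (synthetic stereo 44.1 kHz placeholder)
raw_audio = np.random.randn(2, 44100 * 5).astype(np.float32)  # replace with real audio
audio = preprocess_custom_audio(raw_audio, sample_rate=44100)

inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(input_features=inputs.input_features.to(device))
accent_id = torch.argmax(outputs["accent_logits"], dim=-1).item()
gender_id = torch.argmax(outputs["gender_logits"], dim=-1).item()
print(f"accent_id={accent_id}, gender_id={gender_id}")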
Getting Top-K Predictions
def get_top_k_predictions(audio_file, model, feature_extractor, device, k=3):
    """
    Get top-k accent predictions with confidence scores
    """
    import librosa

    # Load and preprocess audio as in predict_audio above
    audio, sr = librosa.load(audio_file, sr=16000, mono=True)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        outputs = model(input_features=input_features)
        accent_probs = F.softmax(outputs["accent_logits"], dim=-1)

        # Get top-k predictions
        top_k_probs, top_k_indices = torch.topk(accent_probs, k, dim=-1)

    results = []
    for i in range(k):
        results.append({
            'accent_id': top_k_indices[0, i].item(),
            'confidence': top_k_probs[0, i].item()
        })

    return results
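Example call; the audio path is a placeholder, and accent ids can be mapped to names with the accent_names list from predict_audio:

# Example: show the 3 most likely accents for one clip (path is a placeholder)
top3 = get_top_k_predictions("path/to/your/audio.wav", model, feature_extractor, device, k=3)
for rank, entry in enumerate(top3, start=1):
    print(f"{rank}. accent_id={entry['accent_id']} (confidence: {entry['confidence']:.3f})")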
Requirements
- Python 3.8+
- PyTorch 1.9+
- Transformers 4.20+
- librosa (for audio loading)
- safetensors (for loading the model weights)
- numpy
- scikit-learn (for evaluation metrics)
License
This model is based on OpenAI's Whisper and follows the same licensing terms. Please check the original Whisper repository for license details.
Acknowledgments
- OpenAI for the Whisper model
- The English Accent Dataset creators
- Hugging Face Transformers library
Note: This model is trained for research and educational purposes. Performance may vary on different audio qualities, recording conditions, and accent varieties not represented in the training data.