nirmoh committed
Commit 70e4d10 · verified · 1 Parent(s): f821fbd

Upload folder using huggingface_hub
Files changed (3):
  1. .gitattributes +35 -35
  2. README.md +335 -11
  3. accent_classifier.safetensors +3 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,11 +1,335 @@
- ---
- license: mit
- datasets:
- - westbrook/English_Accent_DataSet
- base_model:
- - openai/whisper-small
- pipeline_tag: audio-classification
- tags:
- - accent
- - gender
- ---
+ ---
+ license: mit
+ datasets:
+ - westbrook/English_Accent_DataSet
+ base_model:
+ - openai/whisper-small
+ pipeline_tag: audio-classification
+ tags:
+ - accent
+ - gender
+ ---
+
+ # Whisper Audio Classification Model
+
+ A fine-tuned Whisper model for multi-task audio classification, specifically trained to classify **English accents** (23 classes) and **speaker gender** (2 classes) from speech audio.
+
+ ## 🎯 Model Overview
+
+ This model uses OpenAI's Whisper encoder as a feature extractor with custom classification heads for:
+ - **Accent Classification**: Identifies 23 different English accents
+ - **Gender Classification**: Classifies the speaker as male or female
+
+ ### Model Architecture
+ - **Base Model**: `openai/whisper-small.en`
+ - **Encoder**: Frozen Whisper encoder (used for feature extraction)
+ - **Classification Heads**: Custom feed-forward networks with dropout for robust predictions
+ - **Multi-task Learning**: Jointly trained on both accent and gender classification (a loss sketch is shown after this list)
+
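+ The training code is not included in this card; the following is a minimal sketch of the kind of joint loss such a multi-task setup typically uses, assuming a plain sum of two cross-entropy terms (the weighting actually used during training is not documented here):
+
+ ```python
+ import torch.nn.functional as F
+
+ def multitask_loss(outputs, accent_labels, gender_labels,
+                    accent_weight=1.0, gender_weight=1.0):
+     # outputs: dict with 'accent_logits' and 'gender_logits', as returned by the
+     # model class defined in Quick Start below. The 1.0/1.0 weighting is illustrative.
+     accent_loss = F.cross_entropy(outputs["accent_logits"], accent_labels)
+     gender_loss = F.cross_entropy(outputs["gender_logits"], gender_labels)
+     return accent_weight * accent_loss + gender_weight * gender_loss
+ ```
+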
+ ## 🚀 Quick Start
+
+ ### Prerequisites
+
+ ```bash
+ pip install torch transformers datasets numpy scikit-learn librosa safetensors
+ ```
+
+ ### Basic Usage
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import WhisperFeatureExtractor, WhisperModel
+ from safetensors.torch import load_file
+ import numpy as np
+
+ # Define the model class (same as training)
+ class WhisperClassifier(nn.Module):
+     def __init__(self, model_name="openai/whisper-small.en", num_accent_classes=23, num_gender_classes=2,
+                  freeze_encoder=True, dropout_rate=0.3):
+         super().__init__()
+
+         self.whisper = WhisperModel.from_pretrained(model_name)
+
+         if freeze_encoder:
+             for param in self.whisper.encoder.parameters():
+                 param.requires_grad = False
+
+         self.hidden_size = self.whisper.config.d_model
+         self.dropout = nn.Dropout(dropout_rate)
+
+         # Accent classification head
+         self.accent_classifier = nn.Sequential(
+             nn.Linear(self.hidden_size, 512),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(512, 256),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(256, num_accent_classes)
+         )
+
+         # Gender classification head
+         self.gender_classifier = nn.Sequential(
+             nn.Linear(self.hidden_size, 256),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate),
+             nn.Linear(128, num_gender_classes)
+         )
+
+         self.num_accent_classes = num_accent_classes
+         self.num_gender_classes = num_gender_classes
+
+     def forward(self, input_features, accent_labels=None, gender_labels=None):
+         encoder_outputs = self.whisper.encoder(input_features)
+         hidden_states = encoder_outputs.last_hidden_state
+         pooled_output = hidden_states.mean(dim=1)  # mean-pool over time
+         pooled_output = self.dropout(pooled_output)
+
+         accent_logits = self.accent_classifier(pooled_output)
+         gender_logits = self.gender_classifier(pooled_output)
+
+         return {
+             'accent_logits': accent_logits,
+             'gender_logits': gender_logits,
+         }
+
+ # Load the trained model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = WhisperClassifier()
+
+ # Load the trained weights (safetensors files are read with safetensors, not torch.load;
+ # the filename matches the weights file uploaded in this repository)
+ model.load_state_dict(load_file("accent_classifier.safetensors"))
+ model.to(device)
+ model.eval()
+
+ # Initialize feature extractor
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en")
+ ```
+
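+ If you are pulling the weights straight from the Hugging Face Hub rather than from a local copy, `huggingface_hub` can download the file first. This is a sketch; the `repo_id` below is a placeholder for this model's actual repository id:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ # "<user>/<repo>" is a placeholder -- replace it with this model's repository id
+ weights_path = hf_hub_download(repo_id="<user>/<repo>", filename="accent_classifier.safetensors")
+ model.load_state_dict(load_file(weights_path))
+ model.to(device)
+ model.eval()
+ ```
+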
+ ### Making Predictions
+
+ ```python
+ def predict_audio(audio_file_path, model, feature_extractor, device):
+     """
+     Predict accent and gender from an audio file
+
+     Args:
+         audio_file_path: Path to audio file (.wav, .mp3, etc.)
+         model: Trained WhisperClassifier model
+         feature_extractor: Whisper feature extractor
+         device: torch device (cuda/cpu)
+
+     Returns:
+         Dictionary with predictions and confidence scores
+     """
+     import librosa
+
+     # Load audio file
+     audio, sr = librosa.load(audio_file_path, sr=16000, mono=True)
+
+     # Extract features
+     inputs = feature_extractor(
+         audio,
+         sampling_rate=sr,
+         return_tensors="pt"
+     )
+
+     # Move to device
+     input_features = inputs.input_features.to(device)
+
+     # Get predictions
+     with torch.no_grad():
+         outputs = model(input_features=input_features)
+
+     # Get probabilities
+     accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
+     gender_probs = F.softmax(outputs["gender_logits"], dim=-1)
+
+     # Get predictions
+     accent_pred = torch.argmax(accent_probs, dim=-1).item()
+     gender_pred = torch.argmax(gender_probs, dim=-1).item()
+
+     # Get confidence scores
+     accent_confidence = accent_probs[0, accent_pred].item()
+     gender_confidence = gender_probs[0, gender_pred].item()
+
+     # Map predictions to labels
+     accent_names = [
+         'african', 'australia', 'bermuda', 'canada', 'england', 'hongkong',
+         'indian', 'ireland', 'malaysia', 'newzealand', 'philippines',
+         'scotland', 'singapore', 'southafrica', 'us', 'wales'
+         # Add all 23 accent names based on your dataset
+     ]
+
+     accent_name = accent_names[accent_pred] if accent_pred < len(accent_names) else f"accent_{accent_pred}"
+     gender_name = "male" if gender_pred == 0 else "female"
+
+     return {
+         'accent': accent_name,
+         'accent_confidence': accent_confidence,
+         'gender': gender_name,
+         'gender_confidence': gender_confidence
+     }
+
+ # Example usage
+ result = predict_audio("path/to/your/audio.wav", model, feature_extractor, device)
+ print(f"Predicted Accent: {result['accent']} (confidence: {result['accent_confidence']:.3f})")
+ print(f"Predicted Gender: {result['gender']} (confidence: {result['gender_confidence']:.3f})")
+ ```
+
+ ### Batch Predictions
+
+ ```python
+ def predict_batch(audio_files, model, feature_extractor, device, batch_size=8):
+     """
+     Predict accent and gender for multiple audio files
+     """
+     import librosa
+     from torch.utils.data import DataLoader, Dataset
+
+     class AudioDataset(Dataset):
+         def __init__(self, audio_files):
+             self.audio_files = audio_files
+
+         def __len__(self):
+             return len(self.audio_files)
+
+         def __getitem__(self, idx):
+             audio, sr = librosa.load(self.audio_files[idx], sr=16000, mono=True)
+             inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
+             return inputs.input_features.squeeze(0)
+
+     dataset = AudioDataset(audio_files)
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+     results = []
+     model.eval()
+
+     with torch.no_grad():
+         for batch in dataloader:
+             batch = batch.to(device)
+             outputs = model(input_features=batch)
+
+             accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
+             gender_probs = F.softmax(outputs["gender_logits"], dim=-1)
+
+             accent_preds = torch.argmax(accent_probs, dim=-1)
+             gender_preds = torch.argmax(gender_probs, dim=-1)
+
+             for i in range(len(batch)):
+                 results.append({
+                     'accent_id': accent_preds[i].item(),
+                     'accent_confidence': accent_probs[i, accent_preds[i]].item(),
+                     'gender_id': gender_preds[i].item(),
+                     'gender_confidence': gender_probs[i, gender_preds[i]].item(),
+                 })
+
+     return results
+ ```
+
+ ## 📊 Model Performance
+
+ The model was trained on the English Accent Dataset with the following performance characteristics (an evaluation sketch using scikit-learn follows the list below):
+
+ - **Accent Classification**: Achieves high accuracy across 23 English accent varieties
+ - **Gender Classification**: Robust binary classification for male/female voices
+ - **Multi-task Learning**: Benefits from joint training on both tasks
+
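+ No quantitative metrics are reported in this card. As a minimal sketch of how the claims above could be checked on a labeled evaluation set with scikit-learn, reusing `predict_batch` from the Batch Predictions section (`eval_files`, `accent_true`, and `gender_true` are hypothetical variables you would provide):
+
+ ```python
+ from sklearn.metrics import accuracy_score, classification_report
+
+ # eval_files: list of audio paths; accent_true / gender_true: integer label lists (hypothetical)
+ results = predict_batch(eval_files, model, feature_extractor, device)
+ accent_pred = [r["accent_id"] for r in results]
+ gender_pred = [r["gender_id"] for r in results]
+
+ print("Accent accuracy:", accuracy_score(accent_true, accent_pred))
+ print("Gender accuracy:", accuracy_score(gender_true, gender_pred))
+ print(classification_report(gender_true, gender_pred, target_names=["male", "female"]))
+ ```
+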
+ ### Supported Accent Classes
+
+ The model can classify the following accent varieties (a sketch mapping prediction IDs to these names follows the list):
+ 1. African
+ 2. Australian
+ 3. Bermuda
+ 4. Canadian
+ 5. England
+ 6. Hong Kong
+ 7. Indian
+ 8. Irish
+ 9. Malaysian
+ 10. New Zealand
+ 11. Philippines
+ 12. Scottish
+ 13. Singapore
+ 14. South African
+ 15. US American
+ 16. Welsh
+ ... (and more, totaling 23 classes)
+
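+ `predict_batch` returns integer class IDs rather than names. A minimal way to attach labels is to index into the same `accent_names` list used in `predict_audio`; note that the true ID ordering is defined by the label encoder used at training time, so this mapping is an assumption:
+
+ ```python
+ # Reuses `accent_names` and `results` from the examples above; the id -> name
+ # ordering is assumed to match the training label encoder.
+ for r in results:
+     r["accent"] = accent_names[r["accent_id"]] if r["accent_id"] < len(accent_names) else f"accent_{r['accent_id']}"
+     r["gender"] = "male" if r["gender_id"] == 0 else "female"
+ ```
+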
+ ## 🔧 Advanced Usage
+
+ ### Custom Audio Processing
+
+ ```python
+ def preprocess_custom_audio(audio_array, sample_rate, target_sr=16000):
+     """
+     Preprocess custom audio data
+     """
+     import librosa
+
+     # Resample if needed
+     if sample_rate != target_sr:
+         audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sr)
+
+     # Ensure mono
+     if len(audio_array.shape) > 1:
+         audio_array = librosa.to_mono(audio_array)
+
+     # Normalize (skip silent audio to avoid division by zero)
+     peak = np.max(np.abs(audio_array))
+     if peak > 0:
+         audio_array = audio_array / peak
+
+     return audio_array
+ ```
+
+ ### Getting Top-K Predictions
+
+ ```python
+ def get_top_k_predictions(audio_file, model, feature_extractor, device, k=3):
+     """
+     Get top-k accent predictions with confidence scores
+     """
+     import librosa
+
+     # Load and preprocess audio (same as in predict_audio above)
+     audio, sr = librosa.load(audio_file, sr=16000, mono=True)
+     inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
+     input_features = inputs.input_features.to(device)
+
+     with torch.no_grad():
+         outputs = model(input_features=input_features)
+         accent_probs = F.softmax(outputs["accent_logits"], dim=-1)
+
+     # Get top-k predictions
+     top_k_probs, top_k_indices = torch.topk(accent_probs, k, dim=-1)
+
+     results = []
+     for i in range(k):
+         results.append({
+             'accent_id': top_k_indices[0, i].item(),
+             'confidence': top_k_probs[0, i].item()
+         })
+
+     return results
+ ```
+
+ ## 📋 Requirements
+
+ - Python 3.8+
+ - PyTorch 1.9+
+ - Transformers 4.20+
+ - librosa (for audio loading)
+ - safetensors (for loading the model weights)
+ - numpy
+ - scikit-learn (for evaluation metrics)
+
+ ## 📄 License
+
+ This model is based on OpenAI's Whisper and follows the same licensing terms. Please check the original Whisper repository for license details.
+
+ ## 🙏 Acknowledgments
+
+ - OpenAI for the Whisper model
+ - The English Accent Dataset creators
+ - Hugging Face Transformers library
+
+ ---
+
+ **Note**: This model is trained for research and educational purposes. Performance may vary with audio quality, recording conditions, and accent varieties not represented in the training data.
accent_classifier.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aab970ac55d278a44a5e8bc3a714604c69eb67b5ac2644ebcd2d3faa262c6d6
+ size 970038528