Added Transfer Learning example

#2
Files changed (1) hide show
  1. README.md +142 -33
README.md CHANGED
@@ -51,52 +51,161 @@ Further details are available in the corresponding [**paper**](https://huggingfa
51
  ### Usage
52
 
53
  ```python
54
- import torch
55
- import torch.nn as nn
56
- from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
57
 
58
 
59
- # CONFIG and MODEL SETUP
60
- model_name = 'amiriparian/ExHuBERT'
61
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
62
- model = AutoModelForAudioClassification.from_pretrained(model_name, trust_remote_code=True,
63
- revision="b158d45ed8578432468f3ab8d46cbe5974380812")
64
 
65
- # Freezing half of the encoder for further transfer learning
66
- model.freeze_og_encoder()
67
 
68
- sampling_rate = 16000
69
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
- model = model.to(device)
71
 
72
 
73
 
74
- # Example application from a local audiofile
75
- import numpy as np
76
- import librosa
77
- import torch.nn.functional as F
78
- # Sample taken from the Toronto emotional speech set (TESS) https://tspace.library.utoronto.ca/handle/1807/24487
79
- waveform, sr_wav = librosa.load("YAF_date_angry.wav")
80
- # Max Padding to 3 Seconds at 16k sampling rate for the best results
81
- waveform = feature_extractor(waveform, sampling_rate=sampling_rate,padding = 'max_length',max_length = 48000)
82
- waveform = waveform['input_values'][0]
83
- waveform = waveform.reshape(1, -1)
84
- waveform = torch.from_numpy(waveform).to(device)
85
- with torch.no_grad():
86
- output = model(waveform)
87
- output = F.softmax(output.logits, dim = 1)
88
- output = output.detach().cpu().numpy().round(2)
89
- print(output)
90
 
91
- # [[0. 0. 0. 1. 0. 0.]]
92
- # Low | High Arousal
93
- # Neg. Neut. Pos. | Neg. Neut. Pos Valence
94
- # Disgust, Neutral, Kind| Anger, Surprise, Joy Example emotions
95
 
96
 
97
 
98
  ```
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ### Citation Info
101
 
102
 
 
51
  ### Usage
52
 
53
  ```python
54
+ import torch
55
+ import torch.nn as nn
56
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
57
 
58
 
59
+ # CONFIG and MODEL SETUP
60
+ model_name = 'amiriparian/ExHuBERT'
61
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
62
+ model = AutoModelForAudioClassification.from_pretrained(model_name, trust_remote_code=True,
63
+ revision="b158d45ed8578432468f3ab8d46cbe5974380812")
64
 
65
+ # Freezing half of the encoder for further transfer learning
66
+ model.freeze_og_encoder()
67
 
68
+ sampling_rate = 16000
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ model = model.to(device)
71
 
72
 
73
 
74
+ # Example application from a local audiofile
75
+ import numpy as np
76
+ import librosa
77
+ import torch.nn.functional as F
78
+ # Sample taken from the Toronto emotional speech set (TESS) https://tspace.library.utoronto.ca/handle/1807/24487
79
+ waveform, sr_wav = librosa.load("YAF_date_angry.wav")
80
+ # Max Padding to 3 Seconds at 16k sampling rate for the best results
81
+ waveform = feature_extractor(waveform, sampling_rate=sampling_rate,padding = 'max_length',max_length = 48000)
82
+ waveform = waveform['input_values'][0]
83
+ waveform = waveform.reshape(1, -1)
84
+ waveform = torch.from_numpy(waveform).to(device)
85
+ with torch.no_grad():
86
+ output = model(waveform)
87
+ output = F.softmax(output.logits, dim = 1)
88
+ output = output.detach().cpu().numpy().round(2)
89
+ print(output)
90
 
91
+ # [[0. 0. 0. 1. 0. 0.]]
92
+ # Low | High Arousal
93
+ # Neg. Neut. Pos. | Neg. Neut. Pos Valence
94
+ # Disgust, Neutral, Kind| Anger, Surprise, Joy Example emotions
95
 
96
 
97
 
98
  ```
99
 
100
+ ### Example of How to Train the Model for Transfer Learning
101
+ The datasets used for showcasing are EmoDB and IEMOCAP from the HuggingFace Hub. As noted above, the model has seen both datasets before.
102
+
103
+ ```python
104
+
105
+ import pandas as pd
106
+ import torch
107
+ import torch.nn as nn
108
+ from torch.utils.data import Dataset, DataLoader
109
+ import librosa
110
+ import io
111
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
112
+
113
+ # CONFIG and MODEL SETUP
114
+ model_name = 'amiriparian/ExHuBERT'
115
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
116
+ model = AutoModelForAudioClassification.from_pretrained(model_name, trust_remote_code=True,
117
+ revision="b158d45ed8578432468f3ab8d46cbe5974380812")
118
+
119
+ # Replacing Classifier layer
120
+ model.classifier = nn.Linear(in_features=256, out_features=7)
121
+ # Freezing the original encoder layers and feature encoder (as in the paper) for further transfer learning
122
+ model.freeze_og_encoder()
123
+ model.freeze_feature_encoder()
124
+ model.train()
125
+
126
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
+ model = model.to(device)
128
+
129
+ # Define a custom dataset class
130
+ class EmotionDataset(Dataset):
131
+ def __init__(self, dataframe, feature_extractor, max_length):
132
+ self.dataframe = dataframe
133
+ self.feature_extractor = feature_extractor
134
+ self.max_length = max_length
135
+
136
+ def __len__(self):
137
+ return len(self.dataframe)
138
+
139
+ def __getitem__(self, idx):
140
+ row = self.dataframe.iloc[idx]
141
+ # emotion = torch.tensor(row['label'], dtype=torch.int64) # For the IEMOCAP example
142
+ emotion = torch.tensor(row['emotion'], dtype=torch.int64) # EmoDB specific
143
+
144
+ # Decode audio bytes from the Huggingface dataset with librosa
145
+ audio_bytes = row['audio']['bytes']
146
+ audio_buffer = io.BytesIO(audio_bytes)
147
+ audio_data, samplerate = librosa.load(audio_buffer, sr=16000)
148
+
149
+ # Use the feature extractor to preprocess the audio. Padding/Truncating to 3 seconds gives better results
150
+ audio_features = self.feature_extractor(audio_data, sampling_rate=16000, return_tensors="pt", padding="max_length",
151
+ truncation=True, max_length=self.max_length)
152
+
153
+ audio = audio_features['input_values'].squeeze(0)
154
+ return audio, emotion
155
+
156
+ # Load your DataFrame. Samples are shown for EmoDB and IEMOCAP from the Huggingface Hub
157
+ df = pd.read_parquet("hf://datasets/renumics/emodb/data/train-00000-of-00001-cf0d4b1ae18136ff.parquet")
158
+ # splits = {'session1': 'data/session1-00000-of-00001-04e11ca668d90573.parquet', 'session2': 'data/session2-00000-of-00001-f6132100b374cb18.parquet', 'session3': 'data/session3-00000-of-00001-6e102fcb5c1126b4.parquet', 'session4': 'data/session4-00000-of-00001-e39531a7c694b50d.parquet', 'session5': 'data/session5-00000-of-00001-03769060403172ce.parquet'}
159
+ # df = pd.read_parquet("hf://datasets/Zahra99/IEMOCAP_Audio/" + splits["session1"])
160
+
161
+ # Dataset and DataLoader
162
+ dataset = EmotionDataset(df, feature_extractor, max_length=3 * 16000)
163
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
164
+
165
+ # Training setup
166
+ criterion = nn.CrossEntropyLoss()
167
+ lr = 1e-5
168
+ non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
169
+ optim = torch.optim.AdamW(non_frozen_parameters, lr=lr, betas=(0.9, 0.999), eps=1e-08)
170
+
171
+ # Function to calculate accuracy
172
+ def calculate_accuracy(outputs, targets):
173
+ _, predicted = torch.max(outputs, 1)
174
+ correct = (predicted == targets).sum().item()
175
+ return correct / targets.size(0)
176
+
177
+ # Training loop
178
+ num_epochs = 3
179
+ for epoch in range(num_epochs):
180
+ model.train()
181
+ total_loss = 0.0
182
+ total_correct = 0
183
+ total_samples = 0
184
+ for batch_idx, (inputs, targets) in enumerate(dataloader):
185
+ inputs, targets = inputs.to(device), targets.to(device)
186
+
187
+ optim.zero_grad()
188
+ outputs = model(inputs).logits
189
+ loss = criterion(outputs, targets)
190
+ loss.backward()
191
+ optim.step()
192
+
193
+ total_loss += loss.item()
194
+ total_correct += (outputs.argmax(1) == targets).sum().item()
195
+ total_samples += targets.size(0)
196
+
197
+ epoch_loss = total_loss / len(dataloader)
198
+ epoch_accuracy = total_correct / total_samples
199
+ print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {epoch_loss:.4f}, Average Accuracy: {epoch_accuracy:.4f}')
200
+
201
+ # Example outputs:
202
+ # Epoch [3/3], Average Loss: 0.4572, Average Accuracy: 0.8249 for IEMOCAP
203
+ # Epoch [3/3], Average Loss: 0.1511, Average Accuracy: 0.9850 for EmoDB
204
+
205
+
206
+ ```
207
+
208
+
209
  ### Citation Info
210
 
211