kunnark commited on
Commit
da8a8c5
1 Parent(s): c7708f8

First model commit.

Browse files
Files changed (1) hide show
  1. encoder_wav2vec_classifier.py +159 -0
encoder_wav2vec_classifier.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speechbrain.pretrained import Pretrained
2
+ import torch
3
+
4
+ class EncoderWav2vecClassifier(Pretrained):
5
+ """A ready-to-use class for utterance-level classification (e.g, speaker-id,
6
+ language-id, emotion recognition, keyword spotting, etc).
7
+ The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
8
+ are defined in the yaml file. If you want to
9
+ convert the predicted index into a corresponding text label, please
10
+ provide the path of the label_encoder in a variable called 'lab_encoder_file'
11
+ within the yaml.
12
+ The class can be used either to run only the encoder (encode_batch()) to
13
+ extract embeddings or to run a classification step (classify_batch()).
14
+ ```
15
+ Example
16
+ -------
17
+ >>> import torchaudio
18
+ >>> from speechbrain.pretrained import EncoderClassifier
19
+ >>> # Model is downloaded from the speechbrain HuggingFace repo
20
+ >>> tmpdir = getfixture("tmpdir")
21
+ >>> classifier = EncoderClassifier.from_hparams(
22
+ ... source="speechbrain/spkrec-ecapa-voxceleb",
23
+ ... savedir=tmpdir,
24
+ ... )
25
+ >>> # Compute embeddings
26
+ >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
27
+ >>> embeddings = classifier.encode_batch(signal)
28
+ >>> # Classification
29
+ >>> prediction = classifier .classify_batch(signal)
30
+ """
31
+
32
+ MODULES_NEEDED = ["wav2vec2", "attentive", "classifier"]
33
+
34
+ def __init__(self, *args, **kwargs):
35
+ super().__init__(*args, **kwargs)
36
+
37
+ def encode_batch(self, wavs, wav_lens=None, normalize=False):
38
+ """Encodes the input audio into a single vector embedding.
39
+ The waveforms should already be in the model's desired format.
40
+ You can call:
41
+ ``normalized = <this>.normalizer(signal, sample_rate)``
42
+ to get a correctly converted signal in most cases.
43
+ Arguments
44
+ ---------
45
+ wavs : torch.tensor
46
+ Batch of waveforms [batch, time, channels] or [batch, time]
47
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
48
+ wav_lens : torch.tensor
49
+ Lengths of the waveforms relative to the longest one in the
50
+ batch, tensor of shape [batch]. The longest one should have
51
+ relative length 1.0 and others len(waveform) / max_length.
52
+ Used for ignoring padding.
53
+ normalize : bool
54
+ If True, it normalizes the embeddings with the statistics
55
+ contained in mean_var_norm_emb.
56
+ Returns
57
+ -------
58
+ torch.tensor
59
+ The encoded batch
60
+ """
61
+ # Manage single waveforms in input
62
+ if len(wavs.shape) == 1:
63
+ wavs = wavs.unsqueeze(0)
64
+
65
+ # Assign full length if wav_lens is not assigned
66
+ if wav_lens is None:
67
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
68
+
69
+ # Storing waveform in the specified device
70
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
71
+ wavs = wavs.float()
72
+
73
+ # Feature extraction and normalization
74
+ feats = self.modules.wav2vec2(wavs)
75
+ feats = feats.transpose(1, 2)
76
+
77
+ pooling = self.modules.attentive(feats, wav_lens) # channels = 1024
78
+ outputs = pooling.transpose(1, 2)
79
+ return outputs
80
+
81
+ def classify_batch(self, wavs, wav_lens=None):
82
+ """Performs classification on the top of the encoded features.
83
+ It returns the posterior probabilities, the index and, if the label
84
+ encoder is specified it also the text label.
85
+ Arguments
86
+ ---------
87
+ wavs : torch.tensor
88
+ Batch of waveforms [batch, time, channels] or [batch, time]
89
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
90
+ wav_lens : torch.tensor
91
+ Lengths of the waveforms relative to the longest one in the
92
+ batch, tensor of shape [batch]. The longest one should have
93
+ relative length 1.0 and others len(waveform) / max_length.
94
+ Used for ignoring padding.
95
+ Returns
96
+ -------
97
+ out_prob
98
+ The log posterior probabilities of each class ([batch, N_class])
99
+ score:
100
+ It is the value of the log-posterior for the best class ([batch,])
101
+ index
102
+ The indexes of the best class ([batch,])
103
+ text_lab:
104
+ List with the text labels corresponding to the indexes.
105
+ (label encoder should be provided).
106
+ """
107
+ outputs = self.encode_batch(wavs, wav_lens)
108
+ outputs = self.modules.classifier(outputs)
109
+ out_prob = self.hparams.softmax(outputs)
110
+ score, index = torch.max(out_prob, dim=-1)
111
+ text_lab = self.hparams.label_encoder.decode_torch(index)
112
+ '''
113
+ outputs = self.modules.output_mlp(outputs)
114
+ out_prob = self.hparams.softmax(outputs)
115
+ score, index = torch.max(out_prob, dim=-1)
116
+ text_lab = self.hparams.label_encoder.decode_torch(index)
117
+ '''
118
+ return out_prob, score, index, text_lab
119
+
120
+ def classify_file(self, path):
121
+ """Classifies the given audiofile into the given set of labels.
122
+ Arguments
123
+ ---------
124
+ path : str
125
+ Path to audio file to classify.
126
+ Returns
127
+ -------
128
+ out_prob
129
+ The log posterior probabilities of each class ([batch, N_class])
130
+ score:
131
+ It is the value of the log-posterior for the best class ([batch,])
132
+ index
133
+ The indexes of the best class ([batch,])
134
+ text_lab:
135
+ List with the text labels corresponding to the indexes.
136
+ (label encoder should be provided).
137
+ """
138
+ waveform = self.load_audio(path)
139
+ # Fake a batch:
140
+ batch = waveform.unsqueeze(0)
141
+ rel_length = torch.tensor([1.0])
142
+ outputs = self.encode_batch(batch, rel_length)
143
+
144
+ outputs = self.modules.classifier(outputs)
145
+ # print("classify_outputs_0", outputs.shape)
146
+
147
+ out_prob = self.hparams.softmax(outputs)
148
+ # print("classify_out_1_softmax", out_prob)
149
+ score, index = torch.max(out_prob, dim=-1)
150
+ text_lab = self.hparams.label_encoder.decode_torch(index)
151
+ # print("classify_score_2", score)
152
+ # print("classify_index_3", index)
153
+ # print("classify_textlab_4", text_lab)
154
+ return out_prob, score, index, text_lab
155
+
156
+ def forward(self, wavs, wav_lens=None, normalize=False):
157
+ return self.encode_batch(
158
+ wavs=wavs, wav_lens=wav_lens, normalize=normalize
159
+ )