KELONMYOSA committed
Commit b39f025 · 1 Parent(s): 7260b2e

Upload Wav2Vec2ForSpeechClassification

Files changed (3)
  1. config.json +3 -0
  2. emotion_model.py +116 -0
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -9,6 +9,9 @@
     "Wav2Vec2ForSpeechClassification"
   ],
   "attention_dropout": 0.1,
+  "auto_map": {
+    "AutoModelForAudioClassification": "emotion_model.Wav2Vec2ForSpeechClassification"
+  },
   "bos_token_id": 1,
   "classifier_proj_size": 256,
   "codevector_dim": 256,
emotion_model.py ADDED
@@ -0,0 +1,116 @@
from dataclasses import dataclass
from typing import Tuple, Optional

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.file_utils import ModelOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2ClassificationHead(nn.Module):
    """Projection head that maps pooled hidden states to class logits."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        # Keep the convolutional feature encoder frozen during fine-tuning.
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(self, hidden_states, mode="mean"):
        # Pool per-frame hidden states into one utterance-level vector.
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode="mean")
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # Infer the problem type from num_labels and the label dtype,
            # mirroring the stock transformers sequence-classification heads.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
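
Since the class is a regular Wav2Vec2PreTrainedModel subclass, inference follows the usual transformers pattern: extract features, run the forward pass, take the argmax over logits. A minimal sketch, assuming 16 kHz mono audio and again using `<repo-id>` as a placeholder; the zero waveform is illustrative only:

import torch
from transformers import AutoFeatureExtractor
from emotion_model import Wav2Vec2ForSpeechClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("<repo-id>")
model = Wav2Vec2ForSpeechClassification.from_pretrained("<repo-id>")
model.eval()

waveform = torch.zeros(16000)  # one second of silence as stand-in audio
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, num_labels)
predicted_id = int(logits.argmax(dim=-1))
label = model.config.id2label[predicted_id]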
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:04b8e22b17bfeed234909f4a621e1227bc8711bd420877617ed058e8d6db72d8
3
- size 1266121461
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2166446603c76e6b0af1f4f72448cd632198631b5ea89e3f57fe1c77402e241
3
+ size 1266115509
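
The pointer file records the weight blob's SHA-256 (`oid`) and byte size, which Git LFS uses to fetch and verify the actual file. A hypothetical local integrity check against the new pointer values:

import hashlib, os

# Values copied from the updated LFS pointer above.
path = "pytorch_model.bin"
expected_oid = "e2166446603c76e6b0af1f4f72448cd632198631b5ea89e3f57fe1c77402e241"
expected_size = 1266115509

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch"
assert h.hexdigest() == expected_oid, "sha256 mismatch"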