Amirhossein75 committed
Commit e339308 · 1 Parent(s): 1e3e764

add modeling architecture

config.json CHANGED
@@ -8,6 +8,7 @@
   "architectures": [
     "Wav2Vec2ForCTCAndIntensity"
   ],
+  "AutoModelForCTC": "Amirhossein75/speech-intensity-wav2vec--modeling_wav2vec2_ctc_and_intensity.Wav2Vec2ForCTCAndIntensity",
   "attention_dropout": 0.1,
   "bos_token_id": 1,
   "classifier_proj_size": 256,
modeling_wav2vec2_ctc_and_intensity.py ADDED
@@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC
from transformers.modeling_outputs import CausalLMOutput

try:
    from transformers.modeling_outputs import CTCOutput  # older versions
except ImportError:
    from dataclasses import dataclass
    from typing import Optional, Tuple
    import torch
    from transformers.modeling_outputs import ModelOutput

    @dataclass
    class CTCOutput(ModelOutput):
        loss: Optional[torch.FloatTensor] = None
        logits: torch.FloatTensor = None
        hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class Wav2Vec2ForCTCAndIntensity(Wav2Vec2ForCTC):
    """Wav2Vec2-CTC with an additional regression head for intensity.

    The last hidden state is pooled over time (a masked mean when an attention
    mask is provided) and passed through a small MLP that predicts one scalar
    intensity value per utterance.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size
        # Regression head on top of the pooled encoder representation.
        self.intensity_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.mse = nn.MSELoss()

    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        intensity_value: Optional[torch.FloatTensor] = None,
        lambda_intensity: float = 1.0,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> Union[Tuple, CTCOutput]:
        # Standard Wav2Vec2 CTC forward pass. Hidden states are always requested
        # because the intensity head pools the last hidden state.
        outputs = super().forward(
            input_values=input_values,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            **kwargs
        )

        # Pool the last hidden state for regression: (B, T, H) -> (B, H).
        hidden = outputs.hidden_states[-1] if outputs.hidden_states is not None else None
        intensity_pred = None
        if hidden is not None:
            if attention_mask is not None:
                # attention_mask covers raw audio samples; reduce it to the
                # encoder's frame resolution before masked mean pooling.
                frame_mask = self._get_feature_vector_attention_mask(
                    hidden.shape[1], attention_mask
                )
                mask = frame_mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1)
                summed = (hidden * mask).sum(dim=1)
                denom = mask.sum(dim=1).clamp(min=1.0)
                pooled = summed / denom
            else:
                pooled = hidden.mean(dim=1)
            intensity_pred = self.intensity_head(pooled).squeeze(-1)

        # Joint objective: CTC loss plus lambda_intensity-weighted MSE on intensity.
        ctc_loss = outputs.loss if getattr(outputs, "loss", None) is not None else None
        intensity_loss = None
        if (intensity_pred is not None) and (intensity_value is not None):
            intensity_loss = self.mse(intensity_pred, intensity_value)

        loss = None
        if (ctc_loss is not None) and (intensity_loss is not None):
            loss = ctc_loss + lambda_intensity * intensity_loss
        elif ctc_loss is not None:
            loss = ctc_loss
        elif intensity_loss is not None:
            loss = lambda_intensity * intensity_loss

        if not return_dict:
            # Tuple output, mirroring the parent model's layout with the
            # intensity prediction appended:
            # (loss?, logits, hidden_states, attentions, intensity?).
            output = (outputs.logits, outputs.hidden_states, outputs.attentions)
            if intensity_pred is not None:
                output = output + (intensity_pred,)
            return ((loss,) + output) if loss is not None else output

        # Note: with return_dict=True the intensity prediction is only used for
        # the loss; pass return_dict=False to obtain it from the output tuple.
        return CTCOutput(
            loss=loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
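A minimal inference sketch for the class above, not part of the commit. It assumes the checkpoint Amirhossein75/speech-intensity-wav2vec ships a processor configuration, that the custom class is reachable via auto_map as sketched after the config.json change, and that input audio is 16 kHz mono; the random tensor merely stands in for one second of audio.

import torch
from transformers import AutoModelForCTC, AutoProcessor

repo_id = "Amirhossein75/speech-intensity-wav2vec"  # assumed to ship a processor config
processor = AutoProcessor.from_pretrained(repo_id)
model = AutoModelForCTC.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

speech = torch.randn(16000)  # stand-in for 1 s of 16 kHz mono audio
inputs = processor(speech.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    # return_dict=False exposes the intensity prediction as the last tuple
    # element: (logits, hidden_states, attentions, intensity) when no labels
    # are passed, with the loss prepended when labels are given.
    outputs = model(**inputs, return_dict=False)

logits, intensity = outputs[0], outputs[-1]
transcription = processor.batch_decode(torch.argmax(logits, dim=-1))[0]
print(transcription, float(intensity))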