|
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC
from transformers.modeling_outputs import CausalLMOutput
|
try:
    from transformers.modeling_outputs import CTCOutput
except ImportError:
    # Fallback for transformers releases that do not ship a ``CTCOutput``
    # class. ``Optional``, ``Tuple`` and ``torch`` are already imported at
    # module level, so only the names specific to this branch are imported.
    from dataclasses import dataclass

    from transformers.modeling_outputs import ModelOutput

    @dataclass
    class CTCOutput(ModelOutput):
        """Minimal CTC model output container (all fields default to None)."""

        # Training loss; None when no loss was computed.
        loss: Optional[torch.FloatTensor] = None
        # Output logits of the model.
        logits: Optional[torch.FloatTensor] = None
        # Hidden states per layer, when they were requested.
        hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        # Attention weights per layer, when they were requested.
        attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
@dataclass
class Wav2Vec2CTCIntensityOutput(CTCOutput):
    """``CTCOutput`` extended with the intensity-head prediction.

    Subclassing ``CTCOutput`` keeps ``isinstance`` checks and access to
    ``loss`` / ``logits`` / ``hidden_states`` / ``attentions`` working for
    existing callers. ``intensity_pred`` is declared last so the positions of
    the pre-existing fields in tuple form are unchanged.
    """

    # One scalar prediction per example (the head output squeezed on its last
    # dim); None when no hidden states were available for pooling.
    intensity_pred: Optional[torch.FloatTensor] = None


class Wav2Vec2ForCTCAndIntensity(Wav2Vec2ForCTC):
    """Wav2Vec2-CTC with an additional regression head for intensity.

    The last encoder hidden state is mean-pooled over non-padded frames and
    fed to a small MLP (dropout -> linear -> GELU -> linear) producing one
    scalar per example.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size
        self.intensity_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.mse = nn.MSELoss()

    def _pool_hidden(
        self,
        hidden: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Mean-pool ``hidden`` over its time axis (dim 1), skipping padding."""
        if attention_mask is None:
            return hidden.mean(dim=1)
        # The attention mask given to Wav2Vec2 has one entry per raw input
        # sample, whereas ``hidden`` has one entry per frame produced by the
        # convolutional feature extractor. Reduce the mask to frame resolution
        # first; multiplying the raw mask would not broadcast. The last entry
        # of ``hidden_states`` is the encoder output before any adapter, hence
        # add_adapter=False when computing frame lengths.
        frame_mask = self._get_feature_vector_attention_mask(
            hidden.shape[1], attention_mask, add_adapter=False
        )
        mask = frame_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for fully padded rows.
        denom = mask.sum(dim=1).clamp(min=1.0)
        return summed / denom

    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        intensity_value: Optional[torch.FloatTensor] = None,
        lambda_intensity: float = 1.0,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[Tuple, CTCOutput]:
        """Run the CTC model plus the intensity head and combine their losses.

        Args:
            input_values: Audio input forwarded to ``Wav2Vec2ForCTC``
                (presumably ``(batch, samples)`` -- confirm against the
                feature extractor in use).
            attention_mask: Sample-level padding mask for ``input_values``;
                it is reduced to frame resolution before pooling.
            labels: Forwarded to the base model; its ``loss`` is used as the
                CTC loss.
            intensity_value: Regression target with one value per example.
                It is reshaped to the prediction's shape and cast to its
                dtype/device, so ``(batch,)`` and ``(batch, 1)`` both work.
            lambda_intensity: Weight applied to the intensity MSE loss.
            output_attentions: Forwarded to the base model.
            output_hidden_states: Whether hidden states are included in the
                returned output; ``None`` falls back to
                ``config.output_hidden_states``. The base model is always asked
                for them because the intensity head pools the last one.
            return_dict: Return a ``Wav2Vec2CTCIntensityOutput`` when truthy,
                otherwise a plain tuple.
            **kwargs: Forwarded unchanged to ``Wav2Vec2ForCTC.forward``.

        Returns:
            With ``return_dict``: an output carrying ``loss``, ``logits``,
            ``hidden_states``, ``attentions`` and ``intensity_pred``.
            Otherwise the tuple
            ``([loss,] logits, [hidden_states,] [attentions,] [intensity_pred])``
            where bracketed entries appear only when they are not None.

        Raises:
            RuntimeError: if ``intensity_value`` does not hold exactly one
                element per prediction (raised by ``reshape``).
        """
        # CausalLMOutput does not expose the final encoder state directly, so
        # hidden states are always requested and a dict output is forced here.
        outputs = super().forward(
            input_values=input_values,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        intensity_pred = None
        if outputs.hidden_states is not None:
            pooled = self._pool_hidden(outputs.hidden_states[-1], attention_mask)
            intensity_pred = self.intensity_head(pooled).squeeze(-1)

        intensity_loss = None
        if intensity_pred is not None and intensity_value is not None:
            target = intensity_value.to(
                device=intensity_pred.device, dtype=intensity_pred.dtype
            )
            # Reshape so a (batch, 1) target cannot silently broadcast against
            # the (batch,) prediction into a (batch, batch) comparison.
            intensity_loss = self.mse(
                intensity_pred, target.reshape(intensity_pred.shape)
            )

        loss = outputs.loss
        if intensity_loss is not None:
            weighted = lambda_intensity * intensity_loss
            loss = weighted if loss is None else loss + weighted

        keep_hidden = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        hidden_states = outputs.hidden_states if keep_hidden else None

        if not return_dict:
            # Build the tuple from values explicitly: iterating a ModelOutput
            # yields its key names, not its tensors.
            result = (outputs.logits,)
            if hidden_states is not None:
                result += (hidden_states,)
            if outputs.attentions is not None:
                result += (outputs.attentions,)
            if intensity_pred is not None:
                result += (intensity_pred,)
            return ((loss,) + result) if loss is not None else result

        return Wav2Vec2CTCIntensityOutput(
            loss=loss,
            logits=outputs.logits,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
            intensity_pred=intensity_pred,
        )