# speech-intensity-wav2vec / modeling_wav2vec2_ctc_and_intensity.py
from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC
try:
    # Some older transformers releases exposed a dedicated CTC output class.
    from transformers.modeling_outputs import CTCOutput
except ImportError:
    from dataclasses import dataclass

    from transformers.modeling_outputs import ModelOutput

    @dataclass
    class CTCOutput(ModelOutput):
        """Fallback output container with the standard CTC fields."""

        loss: Optional[torch.FloatTensor] = None
        logits: torch.FloatTensor = None
        hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class Wav2Vec2ForCTCAndIntensity(Wav2Vec2ForCTC):
    """Wav2Vec2-CTC with an additional regression head for intensity.

    The last encoder hidden state is mean-pooled over time (respecting the
    attention mask) and passed through a small MLP that predicts one scalar
    intensity value per utterance. See the usage sketch at the bottom of
    this file.
    """
    def __init__(self, config):
        super().__init__(config)
        hidden = config.hidden_size
        # Small MLP regression head: pooled hidden state -> scalar intensity.
        self.intensity_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.mse = nn.MSELoss()
    def forward(
        self,
        input_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        intensity_value: Optional[torch.FloatTensor] = None,
        lambda_intensity: float = 1.0,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> Union[Tuple, CTCOutput]:
        # Run the standard Wav2Vec2-CTC forward pass. Hidden states are always
        # requested because the intensity head pools the last encoder layer.
        outputs = super().forward(
            input_values=input_values,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            **kwargs
        )
        # Use the last hidden state for regression: (B, T, H).
        hidden = outputs.hidden_states[-1] if outputs.hidden_states is not None else None
        intensity_pred = None
        if hidden is not None:
            if attention_mask is not None:
                # `attention_mask` is at the raw-waveform sample rate, while the
                # encoder output is downsampled, so map it to the frame rate
                # before masked mean pooling over time.
                frame_mask = self._get_feature_vector_attention_mask(
                    hidden.shape[1], attention_mask
                )
                mask = frame_mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1)
                summed = (hidden * mask).sum(dim=1)
                denom = mask.sum(dim=1).clamp(min=1.0)
                pooled = summed / denom
            else:
                pooled = hidden.mean(dim=1)
            intensity_pred = self.intensity_head(pooled).squeeze(-1)  # (B,)
        ctc_loss = outputs.loss  # set only when `labels` were provided
        intensity_loss = None
        if (intensity_pred is not None) and (intensity_value is not None):
            intensity_loss = self.mse(intensity_pred, intensity_value)
        # Combine the two objectives; either one may be absent.
        loss = None
        if (ctc_loss is not None) and (intensity_loss is not None):
            loss = ctc_loss + lambda_intensity * intensity_loss
        elif ctc_loss is not None:
            loss = ctc_loss
        elif intensity_loss is not None:
            loss = lambda_intensity * intensity_loss
        if not return_dict:
            # `outputs` is a ModelOutput (iterating it yields keys, not values),
            # so build the tuple explicitly instead of `list(outputs)`.
            output = (outputs.logits, outputs.hidden_states, outputs.attentions)
            if intensity_pred is not None:
                output = output + (intensity_pred,)
            return ((loss,) + output) if loss is not None else output
        out = CTCOutput(
            loss=loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # ModelOutput supports adding extra keys via item assignment, so the
        # regression output stays reachable as `out.intensity_pred`.
        if intensity_pred is not None:
            out["intensity_pred"] = intensity_pred
        return out
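

# ---------------------------------------------------------------------------
# Minimal usage sketch (a hedged example, not part of the modeling code).
# The checkpoint name "facebook/wav2vec2-base-960h" and the dummy inputs are
# illustrative assumptions; the intensity head is freshly initialized unless
# you load a checkpoint that was actually trained with it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Wav2Vec2ForCTCAndIntensity.from_pretrained("facebook/wav2vec2-base-960h")
    model.eval()

    # Two utterances of fake 16 kHz audio; the second one is padded, and the
    # attention mask marks which samples are valid.
    input_values = torch.randn(2, 16000)
    attention_mask = torch.ones(2, 16000, dtype=torch.long)
    attention_mask[1, 12000:] = 0

    with torch.no_grad():
        out = model(input_values=input_values, attention_mask=attention_mask)

    print(out.logits.shape)          # (batch, frames, vocab_size)
    print(out.intensity_pred.shape)  # (batch,)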