import importlib
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
from transformers.utils import ModelOutput, logging

from .configuration_multitask import MultiTaskClsConfig


logger = logging.get_logger(__name__)


@dataclass
class MultiTaskSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of multi-task sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Sum of the per-task classification (or regression) losses.
        logits_list (`List[torch.FloatTensor]`, one tensor of shape `(batch_size, num_labels)` per task):
            Classification (or regression if the task has a single label) scores for each task (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_list: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class MultiTaskClsModel(PreTrainedModel):
    config_class = MultiTaskClsConfig

    def __init__(self, config: MultiTaskClsConfig):
        super().__init__(config)
        # Resolve the backbone class (e.g. "BertModel") from the configured model type
        # and attach it under its usual attribute name (e.g. "bert") so pretrained
        # checkpoints load into the expected prefix.
        model_cls_str = MODEL_MAPPING_NAMES[config.model_type]
        model_cls = getattr(importlib.import_module("transformers"), model_cls_str)
        transformer_encoder = model_cls._from_config(config)
        self.model_prefix = transformer_encoder.base_model_prefix
        setattr(self, self.model_prefix, transformer_encoder)

        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(classifier_dropout)

        # One classification head per task; labels_list holds the label set of each task.
        self.num_tasks = len(config.labels_list)
        self.labels_list = config.labels_list
        self.num_labels = [
            len(labels) if labels is not None else 1 for labels in self.labels_list
        ]
        self.problem_types = (
            [None] * self.num_tasks
            if config.problem_types is None
            else config.problem_types
        )
        self.cls_task_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, _num_labels)
                for _num_labels in self.num_labels
            ]
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[List[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultiTaskSequenceClassifierOutput]:
        r"""
        labels (`List[torch.LongTensor]`, one tensor of shape `(batch_size,)` per task, *optional*):
            Labels for computing the classification/regression loss of each task. Indices for task `i` should be
            in `[0, ..., num_labels[i] - 1]`. If a task has a single label, a regression loss is computed
            (Mean-Square loss); otherwise a classification loss is computed (Cross-Entropy or BCE-with-logits,
            depending on the task's problem type).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_encoder = getattr(self, self.model_prefix)

        outputs = transformer_encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled sequence representation (e.g. the BERT-style pooler output) shared by all task heads.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        logits_list = [task_head(pooled_output) for task_head in self.cls_task_heads]

        loss = None
        if labels is not None:
            losses = []
            for logits, task_labels, task_type, num_labels in zip(
                logits_list, labels, self.problem_types, self.num_labels
            ):
                # Infer the problem type per task when it is not set in the config,
                # mirroring the heuristic of the single-task classification heads.
                if task_type is None:
                    if num_labels == 1:
                        task_type = "regression"
                    elif num_labels > 1 and task_labels.dtype in (
                        torch.long,
                        torch.int,
                    ):
                        task_type = "single_label_classification"
                    else:
                        task_type = "multi_label_classification"

                if task_type == "regression":
                    loss_fct = nn.MSELoss()
                    if num_labels == 1:
                        task_loss = loss_fct(logits.squeeze(), task_labels.squeeze())
                    else:
                        task_loss = loss_fct(logits, task_labels)
                elif task_type == "single_label_classification":
                    loss_fct = nn.CrossEntropyLoss()
                    if task_labels.shape == logits.view(-1, num_labels).shape:
                        # Soft labels: class probabilities with the same shape as the logits.
                        task_loss = loss_fct(logits.view(-1, num_labels), task_labels)
                    else:
                        task_loss = loss_fct(
                            logits.view(-1, num_labels), task_labels.view(-1)
                        )
                elif task_type == "multi_label_classification":
                    loss_fct = nn.BCEWithLogitsLoss()
                    task_loss = loss_fct(logits, task_labels)
                else:
                    raise ValueError(f"Task type '{task_type}' not supported")

                losses.append(task_loss)

            # The total loss is the sum of the per-task losses.
            loss = torch.stack(losses).sum()

        if not return_dict:
            output = (logits_list,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultiTaskSequenceClassifierOutput(
            loss=loss,
            logits_list=logits_list,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
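

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model). It assumes a
# checkpoint directory saved with this config/model pair and a tokenizer; the
# path is a placeholder and the two-task label values are made up. It shows how
# per-task labels are passed in and how the per-task logits come back in
# `logits_list`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/multitask-checkpoint"  # hypothetical local checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MultiTaskClsModel.from_pretrained(checkpoint)

    batch = tokenizer(["an example sentence"], return_tensors="pt")
    # One label tensor per task, in the same order as config.labels_list
    # (assuming a two-task checkpoint here).
    labels = [torch.tensor([1]), torch.tensor([0])]

    outputs = model(**batch, labels=labels)
    print(outputs.loss, [logits.shape for logits in outputs.logits_list])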