import importlib
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
from transformers.utils import ModelOutput, logging

from .configuration_multitask import MultiTaskClsConfig

logger = logging.get_logger(__name__)


@dataclass
class MultiTaskSequenceClassifierOutput(ModelOutput):
    """
    Output type for multi-task sequence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Sum of the per-task classification (or regression) losses.
        logits_list (`List[torch.FloatTensor]`, each of shape `(batch_size, num_labels)`):
            One tensor of classification (or regression if the task has a single label) scores per task
            (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed
            or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or
            when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_list: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class MultiTaskClsModel(PreTrainedModel):
    config_class = MultiTaskClsConfig

    def __init__(self, config: MultiTaskClsConfig):
        super().__init__(config)

        # Resolve the backbone class (e.g. BertModel) from the config's model_type and
        # store it under the attribute name the backbone expects (its base_model_prefix),
        # so weight loading and saving work as usual.
        model_cls_str = MODEL_MAPPING_NAMES[config.model_type]
        model_cls = getattr(importlib.import_module("transformers"), model_cls_str)
        transformer_encoder = model_cls._from_config(config)
        self.model_prefix = transformer_encoder.base_model_prefix
        setattr(self, self.model_prefix, transformer_encoder)

        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(classifier_dropout)

        # One task per entry in labels_list; a task with no label names is treated as
        # single-output regression. num_tasks is derived from labels_list so that
        # problem_types may legitimately be None.
        self.num_tasks = len(config.labels_list)
        self.labels_list = config.labels_list
        self.num_labels = [
            len(labels) if labels is not None else 1 for labels in self.labels_list
        ]
        self.problem_types = (
            [None] * self.num_tasks
            if config.problem_types is None
            else config.problem_types
        )
        # One classification head per task.
        self.cls_task_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, _num_labels)
                for _num_labels in self.num_labels
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[List[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultiTaskSequenceClassifierOutput]:
        r"""
        labels (`List[torch.Tensor]`, *optional*):
            One label tensor of shape `(batch_size,)` per task, for computing the sequence
            classification/regression losses. For a task with `num_labels == 1` a regression loss is computed
            (Mean-Square loss); for a task with `num_labels > 1` a classification loss is computed
            (Cross-Entropy, or BCE-with-logits for multi-label tasks). Indices should be in
            `[0, ..., num_labels - 1]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Retrieve the backbone stored under self.model_prefix.
        transformer_encoder = getattr(self, self.model_prefix)
        outputs = transformer_encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        # One logits tensor per task.
        logits_list = [task_head(pooled_output) for task_head in self.cls_task_heads]

        loss = None
        if labels is not None:
            losses = []
            for logits, task_labels, task_type, num_labels in zip(
                logits_list, labels, self.problem_types, self.num_labels
            ):
                # Infer the problem type per task when it is not given explicitly,
                # mirroring the heuristic used by HF sequence classification heads.
                if task_type is None:
                    if num_labels == 1:
                        task_type = "regression"
                    elif num_labels > 1 and task_labels.dtype in (torch.long, torch.int):
                        task_type = "single_label_classification"
                    else:
                        task_type = "multi_label_classification"

                if task_type == "regression":
                    loss_fct = nn.MSELoss()
                    if num_labels == 1:
                        task_loss = loss_fct(logits.squeeze(), task_labels.squeeze())
                    else:
                        task_loss = loss_fct(logits, task_labels)
                elif task_type == "single_label_classification":
                    loss_fct = nn.CrossEntropyLoss()
                    if task_labels.shape == logits.view(-1, num_labels).shape:
                        # Labels already provided as per-class targets (e.g. soft labels).
                        task_loss = loss_fct(logits.view(-1, num_labels), task_labels)
                    else:
                        task_loss = loss_fct(
                            logits.view(-1, num_labels), task_labels.view(-1)
                        )
                elif task_type == "multi_label_classification":
                    loss_fct = nn.BCEWithLogitsLoss()
                    task_loss = loss_fct(logits, task_labels)
                else:
                    raise ValueError(f"Task type '{task_type}' not supported")
                losses.append(task_loss)
            # Total loss is the sum of the per-task losses.
            loss = torch.stack(losses).sum()

        if not return_dict:
            output = (logits_list,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultiTaskSequenceClassifierOutput(
            loss=loss,
            logits_list=logits_list,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
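

# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept as comments): the constructor arguments of
# MultiTaskClsConfig shown below (`labels_list`, `problem_types`, backbone
# kwargs) and the checkpoint/tokenizer name are assumptions for the example,
# not part of this module; adapt them to the actual configuration class in
# configuration_multitask.py.
# ---------------------------------------------------------------------------
# from transformers import AutoTokenizer
#
# config = MultiTaskClsConfig(
#     model_type="bert",
#     labels_list=[["negative", "positive"], None],  # task 0: 2-way classification, task 1: regression
#     problem_types=["single_label_classification", "regression"],
# )
# model = MultiTaskClsModel(config)
#
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# batch = tokenizer(["a short example sentence"], return_tensors="pt")
# labels = [torch.tensor([1]), torch.tensor([0.5])]  # one label tensor per task
#
# outputs = model(**batch, labels=labels)
# print(outputs.loss, [logits.shape for logits in outputs.logits_list])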