| from transformers import AutoConfig, PretrainedConfig | |
| from transformers.utils import logging | |
| logger = logging.get_logger(__name__) | |
class MultiTaskClsConfig(PretrainedConfig):
    """Configuration for a multi-task classification model.

    Extends ``PretrainedConfig`` with per-task metadata supplied by the
    caller and with ``num_tasks``, which is derived from ``labels_list``.

    Args:
        problem_types: Stored unchanged on ``self.problem_types``.
            Presumably one problem type per task -- TODO confirm against
            the model code that consumes this config.
        labels_list: Stored unchanged on ``self.labels_list``. Its length
            defines ``num_tasks`` (0 when ``None``). Presumably one
            collection of labels per task -- TODO confirm.
        label2id_dict: Stored unchanged on ``self.label2id_dict``.
            Presumably the label -> id mappings for the tasks -- TODO
            confirm the exact structure.
        id2label_dict: Stored unchanged on ``self.id2label_dict``.
            Presumably the id -> label mappings for the tasks -- TODO
            confirm the exact structure.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "multitaskcls"

    def __init__(
        self,
        problem_types=None,
        labels_list=None,
        label2id_dict=None,
        id2label_dict=None,
        **kwargs,
    ):
        # The base class consumes the options it knows about (normalizing
        # some of them, e.g. casting ``id2label`` keys to int) and stores
        # every remaining kwarg as an attribute. The kwargs must NOT be
        # re-applied with setattr afterwards: that would overwrite the
        # normalized values with the caller's raw ones.
        super().__init__(**kwargs)

        # Always derived from ``labels_list``; a ``num_tasks`` value passed
        # through kwargs (e.g. from a serialized config) is overwritten here.
        self.num_tasks = len(labels_list) if labels_list is not None else 0
        self.labels_list = labels_list
        self.problem_types = problem_types
        # NOTE(review): if these mappings use integer keys, a JSON save/load
        # round-trip turns them into strings -- verify that consumers cope.
        self.label2id_dict = label2id_dict
        self.id2label_dict = id2label_dict