# Uploaded by rntc: "Upload CamemBERT-v2 multitask classifier checkpoint-49500"
# (commit a791a0b, verified; 788 bytes)
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities. It is not used in
# the code shown here (presumably kept for consistency with other transformers
# modules — confirm before removing).
logger = logging.get_logger(__name__)
class MultiTaskClsConfig(PretrainedConfig):
    """Configuration for a multi-task classification model.

    Adds per-task label metadata on top of the standard ``PretrainedConfig``
    fields and derives ``num_tasks`` from ``labels_list``.

    Args:
        problem_types: Per-task problem types, stored as-is. The accepted
            values are not defined in this file (presumably strings such as
            ``"single_label_classification"`` — confirm against the model code).
        labels_list: Per-task label collections. Its length becomes
            ``num_tasks``; ``None`` means zero tasks.
        label2id_dict: Per-task label-to-id mappings, stored as-is.
        id2label_dict: Per-task id-to-label mappings, stored as-is.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``, which
            consumes the arguments it knows and stores the rest as attributes.
    """

    model_type = "multitaskcls"

    def __init__(
        self,
        problem_types=None,
        labels_list=None,
        label2id_dict=None,
        id2label_dict=None,
        **kwargs,
    ):
        # The base class already stores any extra keyword arguments as
        # attributes. Do NOT re-apply the raw kwargs with setattr afterwards:
        # that would overwrite values the base class normalises. For example,
        # ``id2label`` loaded from config.json would revert to string keys
        # instead of int keys.
        super().__init__(**kwargs)
        # A stale ``num_tasks`` coming from a saved config is always replaced
        # by the value derived from ``labels_list``.
        self.num_tasks = len(labels_list) if labels_list is not None else 0
        self.labels_list = labels_list
        self.problem_types = problem_types
        self.label2id_dict = label2id_dict
        self.id2label_dict = id2label_dict