rntc committed
Commit a791a0b · verified · 1 parent: 3a27f88

Upload CamemBERT-v2 multitask classifier checkpoint-49500

config.json ADDED
@@ -0,0 +1,526 @@
{
  "architectures": [
    "MultiTaskClsModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 1,
  "classifier_dropout": null,
  "embedding_size": 768,
  "eos_token_id": 2,
  "finetuning_task": "text-classification",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label_dict": {
    "age_group": { "0": "adult", "1": "elderly", "2": "not_specified", "3": "pediatric" },
    "assertion_type": { "0": "factual", "1": "hypothetical", "2": "mixed", "3": "opinion", "4": "recommendation" },
    "certainty_level": { "0": "definitive", "1": "possible", "2": "probable", "3": "uncertain" },
    "contains_abbreviations": { "0": "0", "1": "1" },
    "contains_bias": { "0": "0", "1": "1" },
    "contains_numbers": { "0": "0", "1": "1" },
    "content_novelty": { "0": "established", "1": "outdated", "2": "recent_developments" },
    "content_richness": { "0": "1", "1": "2", "2": "3", "3": "4", "4": "5" },
    "content_type": { "0": "background_review", "1": "clinical_guidance", "2": "drug_information", "3": "medical_knowledge", "4": "other", "5": "patient_case", "6": "policy_administrative", "7": "research_findings", "8": "research_methodology" },
    "educational_score": { "0": "1", "1": "2", "2": "3", "3": "4", "4": "5" },
    "interactive_elements": { "0": "instructions", "1": "none", "2": "questions", "3": "tasks" },
    "list_format": { "0": "0", "1": "1" },
    "medical_subfield": {
      "0": "anatomical_pathology", "1": "anesthesiology", "2": "biology_medicine", "3": "cardiology", "4": "dentistry",
      "5": "dermatology", "6": "digestive_surgery", "7": "endocrinology", "8": "gastroenterology", "9": "general_medicine",
      "10": "general_surgery", "11": "genetics", "12": "geriatrics", "13": "gynecology_medical", "14": "gynecology_obstetrics",
      "15": "hematology", "16": "intensive_care", "17": "internal_medicine", "18": "maxillofacial_surgery", "19": "midwifery",
      "20": "nephrology", "21": "neurology", "22": "neurosurgery", "23": "nuclear_medicine", "24": "occupational_medicine",
      "25": "oncology", "26": "ophthalmology", "27": "oral_surgery", "28": "orthodontics", "29": "orthopedic_surgery",
      "30": "other", "31": "otolaryngology", "32": "pediatric_surgery", "33": "pediatrics", "34": "pharmacy",
      "35": "plastic_surgery", "36": "pneumology", "37": "psychiatry", "38": "public_health", "39": "radiology",
      "40": "rehabilitation", "41": "rheumatology", "42": "thoracic_surgery", "43": "urologic_surgery", "44": "vascular_surgery"
    },
    "pretraining_suitable": { "0": "0", "1": "1" },
    "rewriting_needed": { "0": "0", "1": "1" },
    "sex": { "0": "female", "1": "male", "2": "not_specified" },
    "terminology_precision": { "0": "1", "1": "2", "2": "3", "3": "4", "4": "5" },
    "text_type": { "0": "incomplete", "1": "meaningful" },
    "writing_quality": { "0": "1", "1": "2", "2": "3", "3": "4", "4": "5" },
    "writing_style": { "0": "academic", "1": "clinical", "2": "other", "3": "pedagogical", "4": "regulatory" }
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id_dict": {
    "age_group": { "adult": 0, "elderly": 1, "not_specified": 2, "pediatric": 3 },
    "assertion_type": { "factual": 0, "hypothetical": 1, "mixed": 2, "opinion": 3, "recommendation": 4 },
    "certainty_level": { "definitive": 0, "possible": 1, "probable": 2, "uncertain": 3 },
    "contains_abbreviations": { "0": 0, "1": 1 },
    "contains_bias": { "0": 0, "1": 1 },
    "contains_numbers": { "0": 0, "1": 1 },
    "content_novelty": { "established": 0, "outdated": 1, "recent_developments": 2 },
    "content_richness": { "1": 0, "2": 1, "3": 2, "4": 3, "5": 4 },
    "content_type": { "background_review": 0, "clinical_guidance": 1, "drug_information": 2, "medical_knowledge": 3, "other": 4, "patient_case": 5, "policy_administrative": 6, "research_findings": 7, "research_methodology": 8 },
    "educational_score": { "1": 0, "2": 1, "3": 2, "4": 3, "5": 4 },
    "interactive_elements": { "instructions": 0, "none": 1, "questions": 2, "tasks": 3 },
    "list_format": { "0": 0, "1": 1 },
    "medical_subfield": {
      "anatomical_pathology": 0, "anesthesiology": 1, "biology_medicine": 2, "cardiology": 3, "dentistry": 4,
      "dermatology": 5, "digestive_surgery": 6, "endocrinology": 7, "gastroenterology": 8, "general_medicine": 9,
      "general_surgery": 10, "genetics": 11, "geriatrics": 12, "gynecology_medical": 13, "gynecology_obstetrics": 14,
      "hematology": 15, "intensive_care": 16, "internal_medicine": 17, "maxillofacial_surgery": 18, "midwifery": 19,
      "nephrology": 20, "neurology": 21, "neurosurgery": 22, "nuclear_medicine": 23, "occupational_medicine": 24,
      "oncology": 25, "ophthalmology": 26, "oral_surgery": 27, "orthodontics": 28, "orthopedic_surgery": 29,
      "other": 30, "otolaryngology": 31, "pediatric_surgery": 32, "pediatrics": 33, "pharmacy": 34,
      "plastic_surgery": 35, "pneumology": 36, "psychiatry": 37, "public_health": 38, "radiology": 39,
      "rehabilitation": 40, "rheumatology": 41, "thoracic_surgery": 42, "urologic_surgery": 43, "vascular_surgery": 44
    },
    "pretraining_suitable": { "0": 0, "1": 1 },
    "rewriting_needed": { "0": 0, "1": 1 },
    "sex": { "female": 0, "male": 1, "not_specified": 2 },
    "terminology_precision": { "1": 0, "2": 1, "3": 2, "4": 3, "5": 4 },
    "text_type": { "incomplete": 0, "meaningful": 1 },
    "writing_quality": { "1": 0, "2": 1, "3": 2, "4": 3, "5": 4 },
    "writing_style": { "academic": 0, "clinical": 1, "other": 2, "pedagogical": 3, "regulatory": 4 }
  },
  "labels_list": [
    ["1", "2", "3", "4", "5"],
    ["1", "2", "3", "4", "5"],
    ["1", "2", "3", "4", "5"],
    ["1", "2", "3", "4", "5"],
    ["0", "1"],
    ["0", "1"],
    ["0", "1"],
    ["academic", "clinical", "other", "pedagogical", "regulatory"],
    ["background_review", "clinical_guidance", "drug_information", "medical_knowledge", "other", "patient_case", "policy_administrative", "research_findings", "research_methodology"],
    ["anatomical_pathology", "anesthesiology", "biology_medicine", "cardiology", "dentistry", "dermatology", "digestive_surgery", "endocrinology", "gastroenterology", "general_medicine", "general_surgery", "genetics", "geriatrics", "gynecology_medical", "gynecology_obstetrics", "hematology", "intensive_care", "internal_medicine", "maxillofacial_surgery", "midwifery", "nephrology", "neurology", "neurosurgery", "nuclear_medicine", "occupational_medicine", "oncology", "ophthalmology", "oral_surgery", "orthodontics", "orthopedic_surgery", "other", "otolaryngology", "pediatric_surgery", "pediatrics", "pharmacy", "plastic_surgery", "pneumology", "psychiatry", "public_health", "radiology", "rehabilitation", "rheumatology", "thoracic_surgery", "urologic_surgery", "vascular_surgery"],
    ["adult", "elderly", "not_specified", "pediatric"],
    ["female", "male", "not_specified"],
    ["factual", "hypothetical", "mixed", "opinion", "recommendation"],
    ["definitive", "possible", "probable", "uncertain"],
    ["0", "1"],
    ["0", "1"],
    ["0", "1"],
    ["instructions", "none", "questions", "tasks"],
    ["established", "outdated", "recent_developments"],
    ["incomplete", "meaningful"]
  ],
  "layer_norm_eps": 1e-07,
  "max_position_embeddings": 1025,
  "model_name": "camembertv2-base",
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_biased_input": true,
  "position_embedding_type": "absolute",
  "problem_types": [
    "single_label_classification", "single_label_classification", "single_label_classification", "single_label_classification",
    "single_label_classification", "single_label_classification", "single_label_classification", "single_label_classification",
    "single_label_classification", "single_label_classification", "single_label_classification", "single_label_classification",
    "single_label_classification", "single_label_classification", "single_label_classification", "single_label_classification",
    "single_label_classification", "single_label_classification", "single_label_classification", "single_label_classification"
  ],
  "torch_dtype": "float32",
  "transformers_version": "4.55.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 32768
}
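For orientation: this configuration declares 20 single-label classification heads, with `labels_list` holding one label set per head, `problem_types` the per-head problem type, and `id2label_dict` / `label2id_dict` the per-task index-to-name mappings. A minimal editorial inspection sketch (not part of the commit; the local "config.json" path is illustrative):

import json

# Inspect the multitask label spaces declared in the config above.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

print(len(cfg["labels_list"]))     # 20 classification heads
print(set(cfg["problem_types"]))   # {"single_label_classification"}

# Per-task index -> label maps, keyed by task name.
for task, id2label in cfg["id2label_dict"].items():
    print(f"{task}: {len(id2label)} classes")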
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34b2a11e468c40c80b7d6cc7c451b6e2f95f80a544ed8597fa2e7452d976b8d5
size 449148280
multitask_transformer/__pycache__/configuration_multitask.cpython-312.pyc ADDED
Binary file (1.26 kB).
 
multitask_transformer/__pycache__/modeling_multitask.cpython-312.pyc ADDED
Binary file (10.1 kB).
 
multitask_transformer/configuration_multitask.py ADDED
@@ -0,0 +1,26 @@
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MultiTaskClsConfig(PretrainedConfig):
    model_type = "multitaskcls"

    def __init__(
        self,
        problem_types=None,
        labels_list=None,
        label2id_dict=None,
        id2label_dict=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # create attributes from the keys in kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.num_tasks = len(labels_list) if labels_list is not None else 0
        self.labels_list = labels_list
        self.problem_types = problem_types
        self.label2id_dict = label2id_dict
        self.id2label_dict = id2label_dict
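A hedged usage sketch (editorial, not part of the commit) showing how the config derives `num_tasks` from `labels_list`; the two tasks below are illustrative, while the real checkpoint defines 20 (see config.json above):

from multitask_transformer.configuration_multitask import MultiTaskClsConfig

# Illustrative values only.
config = MultiTaskClsConfig(
    problem_types=["single_label_classification", "single_label_classification"],
    labels_list=[["0", "1"], ["low", "medium", "high"]],
    hidden_size=768,
)
print(config.num_tasks)  # 2, computed as len(labels_list)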
multitask_transformer/modeling_multitask.py ADDED
@@ -0,0 +1,198 @@
import importlib
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
from transformers.utils import ModelOutput, logging

from .configuration_multitask import MultiTaskClsConfig

logger = logging.get_logger(__name__)


@dataclass
class MultiTaskSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits_list (list of `torch.FloatTensor`, one per task, each of shape `(batch_size, num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_list: List[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class MultiTaskClsModel(PreTrainedModel):
    config_class = MultiTaskClsConfig

    def __init__(self, config: MultiTaskClsConfig):
        super().__init__(config)
        model_cls_str = MODEL_MAPPING_NAMES[config.model_type]
        model_cls = getattr(importlib.import_module("transformers"), model_cls_str)
        transformer_encoder = model_cls._from_config(config)
        self.model_prefix = transformer_encoder.base_model_prefix
        # create a variable with the same name as the prefix
        setattr(self, self.model_prefix, transformer_encoder)

        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(classifier_dropout)

        self.num_tasks = len(config.problem_types)
        self.labels_list = config.labels_list
        self.num_labels = [
            len(labels) if labels is not None else 1 for labels in self.labels_list
        ]
        self.problem_types = (
            [None] * self.num_tasks
            if config.problem_types is None
            else config.problem_types
        )
        self.cls_task_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, _num_labels)
                for _num_labels in self.num_labels
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[List[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], List[MultiTaskSequenceClassifierOutput]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # get attributes from the self.model_prefix
        transformer_encoder = getattr(self, self.model_prefix)

        outputs = transformer_encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)

        # List of logits for each task
        logits_list = [task_head(pooled_output) for task_head in self.cls_task_heads]
        losses = []
        loss = None
        if labels is not None:
            for logits, task_labels, task_type, num_labels in zip(
                logits_list, labels, self.problem_types, self.num_labels
            ):
                if task_type is None:
                    if num_labels == 1:
                        task_type = "regression"
                    elif num_labels > 1 and (
                        task_labels.dtype == torch.long
                        or task_labels.dtype == torch.int
                    ):
                        task_type = "single_label_classification"
                    else:
                        task_type = "multi_label_classification"

                if task_type == "regression":
                    loss_fct = nn.MSELoss()
                    if num_labels == 1:
                        loss = loss_fct(logits.squeeze(), task_labels.squeeze())
                    else:
                        loss = loss_fct(logits, task_labels)
                elif task_type == "single_label_classification":
                    loss_fct = nn.CrossEntropyLoss()
                    if task_labels.shape == logits.view(-1, num_labels).shape:
                        loss = loss_fct(logits.view(-1, num_labels), task_labels)
                    else:
                        loss = loss_fct(
                            logits.view(-1, num_labels), task_labels.view(-1)
                        )
                elif task_type == "multi_label_classification":
                    loss_fct = nn.BCEWithLogitsLoss()
                    loss = loss_fct(logits, task_labels)
                else:
                    raise ValueError(f"Task type '{task_type}' not supported")

                losses.append(loss)

            loss = torch.stack(losses).sum()

        if not return_dict:
            output = (logits_list,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultiTaskSequenceClassifierOutput(
            loss=loss,
            logits_list=logits_list,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
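A hedged inference sketch (editorial, not part of the committed files). It assumes the checkpoint directory has been downloaded locally and that the `multitask_transformer` package is importable from the working directory; the path and example sentence are illustrative:

import torch
from transformers import AutoTokenizer

from multitask_transformer.modeling_multitask import MultiTaskClsModel

ckpt = "./checkpoint-49500"  # local path, illustrative
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = MultiTaskClsModel.from_pretrained(ckpt)
model.eval()

inputs = tokenizer("Le patient présente une douleur thoracique.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One logits tensor per task head; argmax gives the predicted class index per head.
# Mapping indices back to names goes through config.id2label_dict once the
# head-to-task ordering is known (it is not recorded explicitly in config.json).
pred_ids = [logits.argmax(dim=-1).item() for logits in outputs.logits_list]
print(pred_ids)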
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
{
  "bos_token": { "content": "[CLS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "cls_token": { "content": "[CLS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "eos_token": { "content": "[SEP]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "mask_token": { "content": "[MASK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "pad_token": { "content": "[PAD]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "sep_token": { "content": "[SEP]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false },
  "unk_token": { "content": "[UNK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false }
}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
{
  "add_prefix_space": true,
  "added_tokens_decoder": {
    "0": { "content": "[PAD]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
    "1": { "content": "[CLS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
    "2": { "content": "[SEP]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
    "3": { "content": "[UNK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
    "4": { "content": "[MASK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true }
  },
  "bos_token": "[CLS]",
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "eos_token": "[SEP]",
  "errors": "replace",
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 1024,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "tokenizer_class": "RobertaTokenizer",
  "trim_offsets": true,
  "unk_token": "[UNK]"
}
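One practical note (editorial): `model_max_length` is 1024, one below `max_position_embeddings` (1025) in config.json, so long inputs should be truncated at tokenization time. A small illustrative sketch, assuming a local copy of the checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./checkpoint-49500")  # local path, illustrative
enc = tokenizer(
    "Texte clinique potentiellement très long...",
    truncation=True,
    max_length=tokenizer.model_max_length,  # 1024
    return_tensors="pt",
)
print(enc["input_ids"].shape)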
vocab.txt ADDED
The diff for this file is too large to render. See raw diff