Update modeling_norbert.py
modeling_norbert.py (+13 -13)
```diff
@@ -277,12 +277,12 @@ class NorbertPreTrainedModel(PreTrainedModel):
 
 
 class NorbertModel(NorbertPreTrainedModel):
-    def __init__(self, config, add_mlm_layer=False):
-        super().__init__(config)
+    def __init__(self, config, add_mlm_layer=False, gradient_checkpointing=False, **kwargs):
+        super().__init__(config, **kwargs)
         self.config = config
 
         self.embedding = Embedding(config)
-        self.transformer = Encoder(config, activation_checkpointing=False)
+        self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
         self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
 
     def get_input_embeddings(self):
@@ -352,8 +352,8 @@ class NorbertModel(NorbertPreTrainedModel):
 class NorbertForMaskedLM(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=True)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=True, **kwargs)
 
     def get_output_embeddings(self):
         return self.classifier.nonlinearity[-1].weight
@@ -432,8 +432,8 @@ class NorbertForSequenceClassification(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
 
         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)
@@ -498,8 +498,8 @@ class NorbertForTokenClassification(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
 
         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)
@@ -546,8 +546,8 @@ class NorbertForQuestionAnswering(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
 
         self.num_labels = config.num_labels
         self.head = Classifier(config, self.num_labels)
@@ -614,8 +614,8 @@ class NorbertForMultipleChoice(NorbertModel):
     _keys_to_ignore_on_load_unexpected = ["classifier"]
     _keys_to_ignore_on_load_missing = ["head"]
 
-    def __init__(self, config):
-        super().__init__(config, add_mlm_layer=False)
+    def __init__(self, config, **kwargs):
+        super().__init__(config, add_mlm_layer=False, **kwargs)
 
         self.num_labels = getattr(config, "num_labels", 2)
         self.head = Classifier(config, self.num_labels)
```