zhihan1996 committed
Commit
1ce72a5
1 Parent(s): 1c51ecc

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +4 -3
bert_layers.py CHANGED

@@ -18,6 +18,7 @@ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (MaskedLMOutput,
                                            SequenceClassifierOutput)
+from transformers.modeling_bert import BertPreTrainedModel
 from transformers.modeling_utils import PreTrainedModel
 
 from .bert_padding import (index_first_axis,
@@ -521,7 +522,7 @@ class BertPredictionHeadTransform(nn.Module):
         return hidden_states
 
 
-class BertModel(PreTrainedModel):
+class BertModel(BertPreTrainedModel):
     """Overall BERT model.
 
     Args:
@@ -681,7 +682,7 @@ class BertOnlyNSPHead(nn.Module):
 
 
 
-class BertForMaskedLM(PreTrainedModel):
+class BertForMaskedLM(BertPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
@@ -810,7 +811,7 @@ class BertForMaskedLM(PreTrainedModel):
 
 
 
-class BertForSequenceClassification(PreTrainedModel):
+class BertForSequenceClassification(BertPreTrainedModel):
     """Bert Model transformer with a sequence classification/regression head.
 
     This head is just a linear layer on top of the pooled output. Used for,
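
The substance of the commit is a one-line base-class swap plus its import: the three model classes now inherit from `BertPreTrainedModel` rather than the generic `PreTrainedModel`. That matters because `PreTrainedModel._init_weights` is an empty stub, while `BertPreTrainedModel` supplies BERT's initialization scheme and pins `config_class` to `BertConfig`. A minimal sketch of the effect, assuming a recent `transformers` release (where the class lives under `transformers.models.bert.modeling_bert`; the commit itself imports it from the older `transformers.modeling_bert` path). The `TinyHead` class is hypothetical, purely for illustration:

import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertPreTrainedModel


class TinyHead(BertPreTrainedModel):  # hypothetical class, not from the commit
    def __init__(self, config):
        super().__init__(config)
        self.classifier = nn.Linear(config.hidden_size, 2)
        # BertPreTrainedModel._init_weights re-initializes this layer with
        # normal(0, config.initializer_range), matching BERT's scheme.
        # With the generic PreTrainedModel base, _init_weights is a no-op
        # and the layer would keep PyTorch's default initialization.
        self.post_init()

    def forward(self, pooled_output):
        return self.classifier(pooled_output)


config = BertConfig()
model = TinyHead(config)
# config_class is pinned by the base class, so from_pretrained() knows
# which config type to parse when loading a checkpoint.
print(model.config_class is BertConfig)  # True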