Commit
·
162074d
1
Parent(s):
c1c87bf
Add HAT implementation files
Browse files- modelling_hat.py +4 -1
modelling_hat.py
CHANGED
|
@@ -1186,6 +1186,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
|
|
| 1186 |
super().__init__(config)
|
| 1187 |
self.num_labels = config.num_labels
|
| 1188 |
self.config = config
|
|
|
|
| 1189 |
|
| 1190 |
self.hi_transformer = HATModel(config)
|
| 1191 |
self.pooler = HATPooler(config, pooling=pooling)
|
|
@@ -1233,7 +1234,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
|
|
| 1233 |
return_dict=return_dict,
|
| 1234 |
)
|
| 1235 |
sequence_output = outputs[0]
|
| 1236 |
-
pooled_outputs = self.pooler(sequence_output)
|
| 1237 |
|
| 1238 |
drp_loss = None
|
| 1239 |
if labels is not None:
|
|
@@ -1832,6 +1833,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
|
|
| 1832 |
super().__init__(config)
|
| 1833 |
self.num_labels = config.num_labels
|
| 1834 |
self.config = config
|
|
|
|
| 1835 |
self.pooling = pooling
|
| 1836 |
|
| 1837 |
self.hi_transformer = HATModel(config)
|
|
@@ -2043,6 +2045,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
|
|
| 2043 |
super().__init__(config)
|
| 2044 |
|
| 2045 |
self.pooling = pooling
|
|
|
|
| 2046 |
self.hi_transformer = HATModel(config)
|
| 2047 |
classifier_dropout = (
|
| 2048 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
|
|
|
| 1186 |
super().__init__(config)
|
| 1187 |
self.num_labels = config.num_labels
|
| 1188 |
self.config = config
|
| 1189 |
+
self.max_sentence_length = config.max_sentence_length
|
| 1190 |
|
| 1191 |
self.hi_transformer = HATModel(config)
|
| 1192 |
self.pooler = HATPooler(config, pooling=pooling)
|
|
|
|
| 1234 |
return_dict=return_dict,
|
| 1235 |
)
|
| 1236 |
sequence_output = outputs[0]
|
| 1237 |
+
pooled_outputs = self.pooler(sequence_output[:, ::self.max_sentence_length])
|
| 1238 |
|
| 1239 |
drp_loss = None
|
| 1240 |
if labels is not None:
|
|
|
|
| 1833 |
super().__init__(config)
|
| 1834 |
self.num_labels = config.num_labels
|
| 1835 |
self.config = config
|
| 1836 |
+
self.max_sentence_length = config.max_sentence_length
|
| 1837 |
self.pooling = pooling
|
| 1838 |
|
| 1839 |
self.hi_transformer = HATModel(config)
|
|
|
|
| 2045 |
super().__init__(config)
|
| 2046 |
|
| 2047 |
self.pooling = pooling
|
| 2048 |
+
self.max_sentence_length = config.max_sentence_length
|
| 2049 |
self.hi_transformer = HATModel(config)
|
| 2050 |
classifier_dropout = (
|
| 2051 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|