Commit c1c87bf

1 Parent(s): af99e83

Add HAT implementation files

Files changed: modelling_hat.py (+4 -9)
modelling_hat.py CHANGED

@@ -1839,8 +1839,6 @@ class HATForSequenceClassification(HATPreTrainedModel):
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
-        if self.pooling != 'cls':
-            self.sentencizer = HATSentencizer(config)
         self.pooler = HATPooler(config, pooling=pooling)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)

@@ -1885,13 +1883,12 @@ class HATForSequenceClassification(HATPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = outputs[0]
-        if self.pooling
-            sentence_outputs = self.sentencizer(sequence_output)
-            pooled_output = self.pooler(sentence_outputs)
-        elif self.pooling == 'first':
+        if self.pooling == 'first':
             pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, 0, :], 1))
         elif self.pooling == 'last':
             pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
+        else:
+            pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])

         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)

@@ -2051,8 +2048,6 @@ class HATForMultipleChoice(HATPreTrainedModel):
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
-        if self.pooling not in ['first', 'last']:
-            self.sentencizer = HATSentencizer(config)
         self.pooler = HATPooler(config, pooling=pooling)
         self.classifier = nn.Linear(config.hidden_size, 1)

@@ -2113,7 +2108,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
         elif self.pooling == 'last':
             pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
         else:
-            pooled_output = self.pooler(self.
+            pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])

         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
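For context, the net effect of this commit is that both classification heads (HATForSequenceClassification and HATForMultipleChoice) stop building a separate HATSentencizer module and instead derive sentence-level representations with a strided slice of the encoder output, sequence_output[:, ::self.max_sentence_length]. The sketch below illustrates only the tensor shapes involved; it is not the repository's implementation. The block size of 128 is an assumption inferred from the -128 index in the 'last' pooling branch, and all variable names (batch_size, num_sentences, etc.) are illustrative.

import torch

# Assumed layout: the long input is packed as fixed-size sentence blocks of
# max_sentence_length tokens each (128 here, inferred from the -128 index).
batch_size, num_sentences, max_sentence_length, hidden_size = 2, 8, 128, 64
seq_len = num_sentences * max_sentence_length

# Token-level hidden states, standing in for outputs[0] in the diff.
sequence_output = torch.randn(batch_size, seq_len, hidden_size)

# New path: a strided slice keeps positions 0, 128, 256, ... i.e. the first
# token of each sentence block, giving one vector per sentence for the pooler.
sentence_reprs = sequence_output[:, ::max_sentence_length]   # (batch, num_sentences, hidden)

# 'first' / 'last' pooling keep a single position but re-add a length-1 axis
# so the pooler still receives a rank-3 tensor, as in the unchanged branches.
first_repr = torch.unsqueeze(sequence_output[:, 0, :], 1)    # (batch, 1, hidden)
last_repr = torch.unsqueeze(sequence_output[:, -max_sentence_length, :], 1)

print(sentence_reprs.shape, first_repr.shape, last_repr.shape)
# torch.Size([2, 8, 64]) torch.Size([2, 1, 64]) torch.Size([2, 1, 64])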