Commit e3a99d1
Parent(s): 895ac06

Add HAT implementation files
modelling_hat.py  (+20 -0)  CHANGED
@@ -1093,6 +1093,26 @@ class HATForMaskedLM(HATPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+        """Tie or clone module weights depending of whether we are using TorchScript or not"""
+        if self.config.torchscript:
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+        else:
+            output_embeddings.weight = input_embeddings.weight
+
+        if getattr(output_embeddings, "bias", None) is not None:
+            output_embeddings.bias.data = nn.functional.pad(
+                output_embeddings.bias.data,
+                (
+                    0,
+                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
+                ),
+                "constant",
+                0,
+            )
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+
     @add_start_docstrings_to_model_forward(HAT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
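
For context, the override added above matches the generic weight-tying helper in transformers (PreTrainedModel._tie_or_clone_weights): it points the output projection's weight at the input embedding matrix (or clones it when config.torchscript is set, since TorchScript export cannot handle shared parameters), zero-pads the output bias when the tied weight has a larger vocabulary dimension, and keeps out_features in sync with num_embeddings. Below is a minimal, self-contained sketch of the same behavior on plain torch.nn modules; the standalone helper name, module sizes, and usage are illustrative assumptions, not part of this commit.

from torch import nn


def tie_or_clone_weights(output_embeddings, input_embeddings, torchscript=False):
    # Standalone illustration of the logic in the hunk above (hypothetical helper name).
    # Share the embedding matrix with the output projection, or clone it when
    # exporting to TorchScript, which cannot share a single parameter between modules.
    if torchscript:
        output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
    else:
        output_embeddings.weight = input_embeddings.weight

    # Zero-pad the output bias so its length matches the (possibly larger)
    # vocabulary dimension of the weight it is now tied to.
    if getattr(output_embeddings, "bias", None) is not None:
        output_embeddings.bias.data = nn.functional.pad(
            output_embeddings.bias.data,
            (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
            "constant",
            0,
        )

    # Keep the reported output dimension consistent with the embedding table size.
    if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
        output_embeddings.out_features = input_embeddings.num_embeddings


# Example: tie an LM head with a stale vocabulary size to a resized embedding table.
embeddings = nn.Embedding(num_embeddings=30522, embedding_dim=768)
lm_head = nn.Linear(in_features=768, out_features=30000, bias=True)

tie_or_clone_weights(lm_head, embeddings)
assert lm_head.weight is embeddings.weight  # weights are shared, not copied
assert lm_head.bias.shape[0] == 30522       # bias zero-padded to the new vocab size
assert lm_head.out_features == 30522        # metadata kept in sync

In the model itself this logic would normally be driven by tie_weights(), tying the masked-LM decoder (self.lm_head.decoder) to the input embeddings; that call site is outside this hunk, so the usage shown above is only an assumed pattern.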
|