Update modeling_deberta.py
modeling_deberta.py (+36 -55)
@@ -1058,8 +1058,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         )
         encoded_layers = list(encoder_outputs[1])

-        # print(self.z_steps)
-
         if self.z_steps > 0:
             hidden_states = encoded_layers[-2]
             layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
@@ -1100,8 +1098,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
         self.deberta = DebertaV2Model(config)
         self.cls = DebertaV2OnlyMLMHead(config)

-        self.verbose = False
-
         # Initialize weights and apply final processing
         self.post_init()

@@ -1132,19 +1128,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if self.verbose:
-            for i in input_ids[0, :].tolist():
-                print(i, end=", ")
-            print()
-            if attention_mask is not None:
-                for i in attention_mask[0, :].tolist():
-                    print(i, end=", ")
-                print()
-            if position_ids is not None:
-                for i in position_ids[0, :].tolist():
-                    print(i, end=", ")
-                print()
-
         outputs = self.deberta(
             input_ids,
             attention_mask=attention_mask,
@@ -1183,6 +1166,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         super().__init__(config)
         config.is_decoder = True
         self.mask_token_id = config.mask_token_id
+        self.cls_token_id = config.cls_token_id
         self.sep_token_id = config.sep_token_id
         self.n_masks = 3

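Note: the constructor records `cls_token_id` because `prepare_inputs_for_generation` (next hunk) now prepends `[CLS]` itself instead of expecting it from the tokenizer. A minimal sketch of the tokenization this implies — the checkpoint name is an assumption, not part of this diff:

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")  # assumed checkpoint
enc = tok("The capital of France is", add_special_tokens=False, return_tensors="pt")
# enc.input_ids carries no [CLS]/[SEP]; the model inserts them each generation
# step, and the asserts in the next hunk reject pre-added special tokens.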
@@ -1200,12 +1184,39 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
     ):
         position_ids = kwargs.get("position_ids", None)

-
-
-
-
-
-
+        assert input_ids[0, 0] != self.cls_token_id, "`add_special_tokens` should be set to `False`, but `[CLS]` token was detected"
+        assert input_ids[0, -1] != self.sep_token_id, "`add_special_tokens` should be set to `False`, but `[SEP]` token was detected"
+
+        batch_size, seq_length = input_ids.shape
+        input_ids = torch.cat(
+            [
+                torch.full((batch_size, 1), self.cls_token_id, device=input_ids.device),
+                input_ids,
+                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
+                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
+            ],
+            dim=-1
+        )
+
+        if attention_mask is not None:
+            attention_mask = torch.cat(
+                [
+                    torch.full((batch_size, 1), attention_mask[0, 0], device=attention_mask.device),
+                    attention_mask,
+                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
+                ],
+                dim=-1
+            )
+
+        if position_ids is not None:
+            position_ids = torch.cat(
+                [
+                    torch.zeros(batch_size, 1, dtype=position_ids.dtype, device=position_ids.device),
+                    position_ids + 1,
+                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:] + 1,
+                ],
+                dim=-1
+            )

         # Omit tokens covered by past_key_values
         if past_key_values is not None:
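To make the framing arithmetic concrete: for k prompt tokens the hunk above builds `[CLS] t1 … tk [MASK]×n_masks [SEP]`, i.e. k + n_masks + 2 positions, and shifts every position id up by one to make room for the prepended `[CLS]`. A toy trace with illustrative ids (not DeBERTa's real vocabulary):

import torch

cls_id, sep_id, mask_id, n_masks = 1, 2, 4, 3   # illustrative ids only
input_ids = torch.tensor([[11, 12, 13]])        # k = 3 prompt tokens, no special tokens
batch_size = input_ids.shape[0]

framed = torch.cat(
    [
        torch.full((batch_size, 1), cls_id),
        input_ids,
        torch.full((batch_size, n_masks), mask_id),
        torch.full((batch_size, 1), sep_id),
    ],
    dim=-1,
)
print(framed)  # tensor([[ 1, 11, 12, 13,  4,  4,  4,  2]]) -> k + n_masks + 2 = 8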
@@ -1228,7 +1239,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
             {
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
-                "use_cache":
+                "use_cache": None,
                 "attention_mask": attention_mask,
             }
         )
@@ -1255,36 +1266,6 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         assert past_key_values is None, "past_key_values is not supported for now"
         assert use_cache is None, "use_cache is not supported for now"

-        assert input_ids[0, -1] != self.sep_token_id, "remove the last token if it is a sep token"
-
-        batch_size, seq_length = input_ids.shape
-        input_ids = torch.cat(
-            [
-                input_ids,
-                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
-                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
-            ],
-            dim=-1
-        )
-
-        if attention_mask is not None:
-            attention_mask = torch.cat(
-                [
-                    attention_mask,
-                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
-                ],
-                dim=-1
-            )
-
-        if position_ids is not None:
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
-                ],
-                dim=-1
-            )
-
         outputs = super().forward(
             input_ids,
             attention_mask=attention_mask,
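The block removed here is the old, pre-`[CLS]` version of the framing that hunk `@@ -1200,12 +1184,39 @@` re-adds inside `prepare_inputs_for_generation`. After the move, `forward` expects already-framed ids, so only the `generate()` path frames raw prompts. A hedged end-to-end sketch — the import path and checkpoint are assumptions, and weight compatibility is not guaranteed by this diff:

import torch
from transformers import AutoTokenizer
from modeling_deberta import DebertaV2ForCausalLM  # this repo's modified module

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")     # assumed
model = DebertaV2ForCausalLM.from_pretrained("microsoft/deberta-v3-base")
prompt = tok("Paris is the capital of", add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**prompt, max_new_tokens=4)  # framing applied per step
print(tok.decode(out[0], skip_special_tokens=True))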
@@ -1297,7 +1278,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         )

         # shift the outputs and skip excess masks
-        logits = outputs.logits[:,
+        logits = outputs.logits[:, 2:-self.n_masks, :].contiguous()

         loss = None
         if labels is not None:
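As a length check on the new slice: the framed sequence has 1 + k + n_masks + 1 positions, so `outputs.logits[:, 2:-self.n_masks, :]` keeps (k + n_masks + 2) - 2 - n_masks = k logits — one per original token, with the start offset of 2 reading as the prepended `[CLS]` plus the one-position shift the comment names. A toy verification:

import torch

k, n_masks, vocab = 5, 3, 10                     # illustrative sizes only
logits = torch.randn(1, k + n_masks + 2, vocab)  # [CLS] + k tokens + masks + [SEP]
sliced = logits[:, 2:-n_masks, :]
assert sliced.shape[1] == k                      # one prediction per prompt token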