davda54 committed
Commit: cd99491
1 parent: ac3afa1

Update modeling_deberta.py

Files changed (1): modeling_deberta.py (+36, -55)
modeling_deberta.py CHANGED

@@ -1058,8 +1058,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         )
         encoded_layers = list(encoder_outputs[1])
 
-        # print(self.z_steps)
-
         if self.z_steps > 0:
             hidden_states = encoded_layers[-2]
             layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
@@ -1100,8 +1098,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
         self.deberta = DebertaV2Model(config)
         self.cls = DebertaV2OnlyMLMHead(config)
 
-        self.verbose = False
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1132,19 +1128,6 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if self.verbose:
-            for i in input_ids[0, :].tolist():
-                print(i, end=", ")
-            print()
-            if attention_mask is not None:
-                for i in attention_mask[0, :].tolist():
-                    print(i, end=", ")
-                print()
-            if position_ids is not None:
-                for i in position_ids[0, :].tolist():
-                    print(i, end=", ")
-                print()
-
         outputs = self.deberta(
             input_ids,
             attention_mask=attention_mask,
@@ -1183,6 +1166,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         super().__init__(config)
         config.is_decoder = True
         self.mask_token_id = config.mask_token_id
+        self.cls_token_id = config.cls_token_id
         self.sep_token_id = config.sep_token_id
         self.n_masks = 3
 
@@ -1200,12 +1184,39 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
     ):
         position_ids = kwargs.get("position_ids", None)
 
-        if input_ids[0, -1] == 2:
-            input_ids = input_ids[:, :-1]
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, :-1]
-            if position_ids is not None:
-                position_ids = position_ids[:, :-1]
+        assert input_ids[0, 0] != self.cls_token_id, "`add_special_tokens` should be set to `False`, but `[CLS]` token was detected"
+        assert input_ids[0, -1] != self.sep_token_id, "`add_special_tokens` should be set to `False`, but `[SEP]` token was detected"
+
+        batch_size, seq_length = input_ids.shape
+        input_ids = torch.cat(
+            [
+                torch.full((batch_size, 1), self.cls_token_id, device=input_ids.device),
+                input_ids,
+                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
+                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
+            ],
+            dim=-1
+        )
+
+        if attention_mask is not None:
+            attention_mask = torch.cat(
+                [
+                    torch.full((batch_size, 1), attention_mask[0, 0], device=attention_mask.device),
+                    attention_mask,
+                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
+                ],
+                dim=-1
+            )
+
+        if position_ids is not None:
+            position_ids = torch.cat(
+                [
+                    torch.zeros(batch_size, 1, device=position_ids.device),
+                    position_ids + 1,
+                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:] + 1,
+                ],
+                dim=-1
+            )
 
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1228,7 +1239,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
             {
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
+                "use_cache": None,
                 "attention_mask": attention_mask,
             }
         )
@@ -1255,36 +1266,6 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         assert past_key_values is None, "past_key_values is not supported for now"
         assert use_cache is None, "use_cache is not supported for now"
 
-        assert input_ids[0, -1] != self.sep_token_id, "remove the last token if it is a sep token"
-
-        batch_size, seq_length = input_ids.shape
-        input_ids = torch.cat(
-            [
-                input_ids,
-                torch.full((batch_size, self.n_masks), self.mask_token_id, device=input_ids.device),
-                torch.full((batch_size, 1), self.sep_token_id, device=input_ids.device)
-            ],
-            dim=-1
-        )
-
-        if attention_mask is not None:
-            attention_mask = torch.cat(
-                [
-                    attention_mask,
-                    torch.full((batch_size, self.n_masks + 1), attention_mask[0, -1], device=attention_mask.device),
-                ],
-                dim=-1
-            )
-
-        if position_ids is not None:
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(0, self.n_masks + 1, device=position_ids.device).unsqueeze(0) + position_ids[:, -1:],
-                ],
-                dim=-1
-            )
-
         outputs = super().forward(
             input_ids,
             attention_mask=attention_mask,
@@ -1297,7 +1278,7 @@ class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
         )
 
         # shift the outputs and skip excess masks
-        logits = outputs.logits[:, 1:-(self.n_masks), :].contiguous()
+        logits = outputs.logits[:, 2:-self.n_masks, :].contiguous()
 
         loss = None
         if labels is not None:
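
For readers trying the model out, a minimal usage sketch consistent with the new checks (not part of the commit): the prompt must be tokenized with add_special_tokens=False, because prepare_inputs_for_generation now prepends [CLS] and appends the [MASK] block and [SEP] itself. The checkpoint id below is a placeholder, and trust_remote_code=True is assumed since this modeling file ships with the repository rather than with transformers.

# Hypothetical usage sketch; "ltg/<checkpoint>" is a placeholder model id.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ltg/<checkpoint>", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ltg/<checkpoint>", trust_remote_code=True)

# The commit moves special-token handling into the model, so tokenize the
# prompt WITHOUT [CLS]/[SEP]; the new asserts reject inputs that contain them.
inputs = tokenizer("The capital of Norway is", add_special_tokens=False, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0]))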
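
And a standalone shape check (made-up token ids, for illustration only) of why the logits slice changes from 1:-(self.n_masks) to 2:-self.n_masks: the wrapped sequence now carries one extra leading [CLS] position, so one more position is dropped from the front and the slice still returns exactly one logit vector per prompt token.

import torch

cls_id, mask_id, sep_id, n_masks = 1, 4, 2, 3       # made-up ids
prompt = torch.tensor([[17, 23, 42]])               # batch of 1, k = 3 prompt tokens
batch_size, k = prompt.shape

# Same concatenation as the new prepare_inputs_for_generation.
wrapped = torch.cat(
    [
        torch.full((batch_size, 1), cls_id),
        prompt,
        torch.full((batch_size, n_masks), mask_id),
        torch.full((batch_size, 1), sep_id),
    ],
    dim=-1,
)
assert wrapped.shape == (batch_size, k + n_masks + 2)   # [CLS] + prompt + masks + [SEP]

# forward() keeps logits[:, 2:-n_masks, :]; with the extra [CLS] up front the
# slice still yields exactly k positions, one per prompt token.
dummy_logits = torch.randn(batch_size, wrapped.shape[1], 10)
assert dummy_logits[:, 2:-n_masks, :].shape[1] == k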