Commit: Fix wrong tuple count issue after reapply

Files changed: modeling_mpt.py (+6, −8)
--- a/modeling_mpt.py
+++ b/modeling_mpt.py
@@ -248,7 +248,7 @@ class MPTModel(MPTPreTrainedModel):
 
                         return custom_forward
 
-                    (x,
+                    (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(block),
                         x,
                         past_key_value,
@@ -256,15 +256,13 @@ class MPTModel(MPTPreTrainedModel):
                         attention_mask,
                         self.is_causal,
                     )
-                    if past_key_values is not None:
-                        past_key_values[b_idx] = past_key_value
                 else:
                     (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
-
-
-
-
-
+                if presents is not None:
+                    presents += (present,)
+                if output_attentions:
+                    assert all_self_attns is not None
+                    all_self_attns = all_self_attns + (attn_weights,)
 
 
         x = self.norm_f(x)