Update modeling_xlm_roberta.py
modeling_xlm_roberta.py  +4 -6
@@ -210,12 +210,10 @@ class XLMRobertaEncoder(nn.Module):
             subset_mask: (batch, seqlen), dtype=torch.bool
         """
         if key_padding_mask is None or not self.use_flash_attn:
-            mixer_kwargs = (
-                {'key_padding_mask': key_padding_mask.bool()}
-                if key_padding_mask is not None
-                else None
-            )
-            mixer_kwargs['task_type'] = task_type
+            mixer_kwargs = {'task_type': task_type}
+            if key_padding_mask is not None:
+                mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
+
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
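For context: the old construction evaluated to None whenever key_padding_mask was None, so the follow-up item assignment mixer_kwargs['task_type'] = task_type raised a TypeError. Below is a minimal, self-contained sketch of the before/after kwargs logic, pulled out of the encoder for illustration; the helper names and the 'retrieval' task_type value are placeholders, not part of the model code.

    import torch

    def build_mixer_kwargs_old(key_padding_mask, task_type):
        # Old logic: the conditional expression evaluates to None
        # when no mask is given.
        mixer_kwargs = (
            {'key_padding_mask': key_padding_mask.bool()}
            if key_padding_mask is not None
            else None
        )
        mixer_kwargs['task_type'] = task_type  # TypeError when mixer_kwargs is None
        return mixer_kwargs

    def build_mixer_kwargs_new(key_padding_mask, task_type):
        # New logic: always start from a dict; add the mask only if present.
        mixer_kwargs = {'task_type': task_type}
        if key_padding_mask is not None:
            mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
        return mixer_kwargs

    mask = torch.ones(2, 8, dtype=torch.long)  # hypothetical (batch, seqlen) mask
    print(build_mixer_kwargs_new(mask, 'retrieval'))  # both keys present
    print(build_mixer_kwargs_new(None, 'retrieval'))  # {'task_type': 'retrieval'}
    # build_mixer_kwargs_old(None, 'retrieval') would raise:
    # TypeError: 'NoneType' object does not support item assignment

The fix also keeps the key_padding_mask.bool() call inside the not-None branch, so it is only evaluated when a mask actually exists.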