Update modeling_llama2.py
Browse files- modeling_llama2.py +1 -1
modeling_llama2.py
CHANGED
@@ -22,7 +22,7 @@ from transformers.models.llama.modeling_llama import *
|
|
22 |
from transformers.configuration_utils import PretrainedConfig
|
23 |
from transformers.utils import logging
|
24 |
|
25 |
-
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
26 |
from .configuration_mplug_owl2 import LlamaConfig
|
27 |
|
28 |
class MultiwayNetwork(nn.Module):
|
|
|
22 |
from transformers.configuration_utils import PretrainedConfig
|
23 |
from transformers.utils import logging
|
24 |
|
25 |
+
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
26 |
from .configuration_mplug_owl2 import LlamaConfig
|
27 |
|
28 |
class MultiwayNetwork(nn.Module):
|