disable `torch.jit` when using Ascend NPUs (#20)
Browse files- disable `torch.jit` when using Ascend NPUs (33d113454549f9ece51919bc6136d329887e9aa0)
Co-authored-by: Huazhong <[email protected]>
- modeling_chatglm.py +2 -2
modeling_chatglm.py
CHANGED
|
@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
|
|
| 21 |
SequenceClassifierOutputWithPast,
|
| 22 |
)
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
-
from transformers.utils import logging
|
| 25 |
from transformers.generation.logits_process import LogitsProcessor
|
| 26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
| 27 |
|
|
@@ -29,7 +29,7 @@ from .configuration_chatglm import ChatGLMConfig
|
|
| 29 |
|
| 30 |
# flags required to enable jit fusion kernels
|
| 31 |
|
| 32 |
-
if sys.platform != 'darwin':
|
| 33 |
torch._C._jit_set_profiling_mode(False)
|
| 34 |
torch._C._jit_set_profiling_executor(False)
|
| 35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
|
|
| 21 |
SequenceClassifierOutputWithPast,
|
| 22 |
)
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
+
from transformers.utils import logging, is_torch_npu_available
|
| 25 |
from transformers.generation.logits_process import LogitsProcessor
|
| 26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
| 27 |
|
|
|
|
| 29 |
|
| 30 |
# flags required to enable jit fusion kernels
|
| 31 |
|
| 32 |
+
if sys.platform != 'darwin' and not is_torch_npu_available():
|
| 33 |
torch._C._jit_set_profiling_mode(False)
|
| 34 |
torch._C._jit_set_profiling_executor(False)
|
| 35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|