fix: set fp32 when using cpu bc bf16 is slow (#44)
Commit: 6787a0f57730d94a2dda30bf54ab96382ce09536
configuration_xlm_roberta.py
CHANGED
@@ -126,3 +126,5 @@ class XLMRobertaFlashConfig(PretrainedConfig):
             self.torch_dtype = getattr(torch, torch_dtype)
         else:
             self.torch_dtype = torch_dtype
+        if not self.use_flash_attn or not torch.cuda.is_available():
+            self.torch_dtype = torch.float32
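
For context, a minimal sketch of the intended effect. The class and attribute names come from the diff; the constructor keyword arguments below are assumed for illustration, not confirmed by this commit:

import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig

# bf16 matmuls are slow on CPU, so the config now falls back to fp32
# whenever flash attention is disabled or no CUDA device is available.
config = XLMRobertaFlashConfig(use_flash_attn=False, torch_dtype="bfloat16")
assert config.torch_dtype == torch.float32  # overridden by the new check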