fix: override use_flash_attn in lora
modeling_lora.py  (+1 -4)

@@ -322,12 +322,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
         if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
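For orientation, here is a minimal sketch of how the patched branch reads after this commit. Only the lines visible in the hunk above come from the source; the enclosing class, the import path, and the `config=None` parameter are assumptions about surrounding code that the hunk does not show.

# Abridged sketch, not the repository's full modeling_lora.py. The enclosing class,
# the import, and the `config` parameter are assumptions; only the lines shown in
# the hunk above are taken from the commit itself.
from .modeling_xlm_roberta import XLMRobertaPreTrainedModel  # assumed relative import within the same repo


class XLMRobertaLoRA(XLMRobertaPreTrainedModel):

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,  # assumed: the XLMRobertaFlashConfig resolved before this branch runs
        use_safetensors: bool = None,
        **kwargs,
    ):
        # The commit removes the redundant XLMRobertaFlashConfig.from_pretrained(...)
        # call that used to rebuild `config` at this point.
        if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
            # Forward the flash-attention flag from the resolved config so that value,
            # rather than a default picked up by the parent class, takes effect.
            return super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                use_flash_attn=config.use_flash_attn,
                **kwargs,
            )
        # The else-branch (initializing new adapters via XLMRobertaModel.from_pretrained)
        # is unchanged by this commit and omitted here.

The net effect is that, when loading a checkpoint that already contains LoRA adapters, `use_flash_attn` is passed through explicitly from the resolved config rather than left to whatever `super().from_pretrained()` would otherwise use.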