jupyterjazz
committed on
Commit dc4080e • 1 Parent(s): 169b7fb
fix: override use_flash_attn in lora
modeling_lora.py CHANGED (+1, -4)
@@ -322,12 +322,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
         if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
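In short, the hunk stops rebuilding the config inside this method and, when the checkpoint already ships trained LoRA adapters, forwards use_flash_attn=config.use_flash_attn to super().from_pretrained() explicitly instead of leaving it to **kwargs. A minimal usage sketch of the behavior this protects, assuming the model is loaded from the Hub with trust_remote_code (the repo id below is a placeholder, not a real checkpoint):

# Hedged sketch, not part of the commit: demonstrates the override that the
# fix preserves. "your-org/your-lora-checkpoint" is a hypothetical repo id;
# AutoConfig/AutoModel with trust_remote_code is the standard way to load
# custom code such as modeling_lora.py from the Hub.
from transformers import AutoConfig, AutoModel

# Resolve the custom config and set the flag up front.
config = AutoConfig.from_pretrained(
    "your-org/your-lora-checkpoint",  # hypothetical repo id
    trust_remote_code=True,
    use_flash_attn=False,  # the override this commit makes effective
)

# With load_trained_adapters=True, XLMRobertaLoRA.from_pretrained() re-enters
# super().from_pretrained(); before this fix that call received only **kwargs,
# so the flag above could be dropped. The fix forwards it explicitly.
model = AutoModel.from_pretrained(
    "your-org/your-lora-checkpoint",
    config=config,
    trust_remote_code=True,
)

Forwarding the flag from the already-resolved config keeps a single source of truth for use_flash_attn, rather than letting the superclass re-derive it from kwargs and silently reset it.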