[Fix bug] TypeError: argument of type 'XLMRobertaFlashConfig' is not iterable (#55)
Commit 7207e6dc1a4f92525661684f15d6778d84cfdf3c
Co-authored-by: Le Khac Phuong <[email protected]>
- modeling_lora.py +15 -13
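For context on the error named in the title: `XLMRobertaFlashConfig` subclasses `transformers.PretrainedConfig`, which defines neither `__contains__` nor `__iter__`, so a plain `key in config` membership test raises this `TypeError`. A minimal repro sketch, using the base `PretrainedConfig` as a stand-in for the flash config and `hidden_size=768` only as an illustrative kwarg:

```python
from transformers import PretrainedConfig

# Stand-in for XLMRobertaFlashConfig, which subclasses PretrainedConfig.
# hidden_size=768 is only an illustrative kwarg.
config = PretrainedConfig(hidden_size=768)

try:
    "hidden_size" in config  # no __contains__/__iter__ on the config object
except TypeError as err:
    print(err)  # argument of type 'PretrainedConfig' is not iterable

# Membership against the dict representation is what the fix switches to.
print("hidden_size" in config.to_dict())  # True
```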
modeling_lora.py
CHANGED
@@ -11,16 +11,12 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-from .rotary import RotaryEmbedding
-from .mlp import FusedMLP, Mlp
-from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
-from .stochastic_depth import stochastic_depth
-from .mha import MHA
-from .block import Block
 from .configuration_xlm_roberta import XLMRobertaFlashConfig
-from .
-
-
+from .modeling_xlm_roberta import (
+    XLMRobertaFlashConfig,
+    XLMRobertaModel,
+    XLMRobertaPreTrainedModel,
+)
 
 
 def initialized_weights(
@@ -336,7 +332,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         **kwargs,
     ):
         for key in list(kwargs.keys()):
-            if key in config:
+            if key in config.to_dict():
                 config.update({key: kwargs.pop(key)})
         if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
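To make the fixed override loop concrete in isolation: the sketch below wraps it in a hypothetical `apply_config_overrides` helper (not part of the repository). `use_flash_attn` is a real field of `XLMRobertaFlashConfig` used later in this diff, while `trust_remote_code` stands in for a kwarg that is not a config field and should keep flowing to `from_pretrained`:

```python
from transformers import PretrainedConfig

def apply_config_overrides(config: PretrainedConfig, **kwargs) -> dict:
    """Move kwargs that name existing config fields into the config.

    Membership is tested against config.to_dict() because the config object
    itself does not support `in` (the TypeError this commit fixes).
    """
    for key in list(kwargs.keys()):
        if key in config.to_dict():
            config.update({key: kwargs.pop(key)})
    return kwargs  # leftovers are forwarded to the underlying from_pretrained()

cfg = PretrainedConfig(use_flash_attn=True)
rest = apply_config_overrides(cfg, use_flash_attn=False, trust_remote_code=True)
print(cfg.use_flash_attn)  # False (moved into the config)
print(rest)                # {'trust_remote_code': True} (left for from_pretrained)
```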
@@ -350,11 +346,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 token=token,
                 revision=revision,
                 use_safetensors=use_safetensors,
-                **kwargs
+                **kwargs,
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
-                pretrained_model_name_or_path,
+                pretrained_model_name_or_path,
+                *model_args,
+                use_flash_attn=config.use_flash_attn,
+                **kwargs,
             )
             return cls(config, roberta=roberta)
 
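This hunk also passes `*model_args`, `use_flash_attn` from the config, and the remaining `**kwargs` through to the inner `XLMRobertaModel.from_pretrained` call. A toy sketch of that forwarding shape, with made-up `Backbone` and `LoRAWrapper` classes so the argument flow can be run standalone:

```python
class Backbone:
    """Toy stand-in for XLMRobertaModel; records what from_pretrained received."""

    def __init__(self, name, model_args, kwargs):
        self.name, self.model_args, self.kwargs = name, model_args, kwargs

    @classmethod
    def from_pretrained(cls, name, *model_args, **kwargs):
        return cls(name, model_args, kwargs)


class LoRAWrapper:
    """Toy stand-in for XLMRobertaLoRA; only the argument forwarding is shown."""

    def __init__(self, config, roberta):
        self.config, self.roberta = config, roberta

    @classmethod
    def from_pretrained(cls, name, *model_args, config=None, **kwargs):
        roberta = Backbone.from_pretrained(
            name,
            *model_args,
            use_flash_attn=config["use_flash_attn"],  # pinned from the config
            **kwargs,                                 # e.g. torch_dtype, revision
        )
        return cls(config, roberta=roberta)


model = LoRAWrapper.from_pretrained(
    "org/some-model", config={"use_flash_attn": False}, torch_dtype="float16"
)
print(model.roberta.kwargs)  # {'use_flash_attn': False, 'torch_dtype': 'float16'}
```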
@@ -418,7 +417,10 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         if isinstance(sentences, str):
             sentences = self._task_instructions[task] + sentences
         else:
-            sentences = [
+            sentences = [
+                self._task_instructions[task] + sentence for sentence in sentences
+            ]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
+
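The last hunk rewrites the list branch of `encode` as a multi-line list comprehension, so the task instruction is prepended to every sentence in a batch as well as to a single string. A small self-contained sketch of that prefixing logic; `TASK_INSTRUCTIONS` and its prompt text are made up here and only stand in for the model's `_task_instructions`:

```python
from typing import List, Union

# Hypothetical instruction table standing in for self._task_instructions.
TASK_INSTRUCTIONS = {
    "retrieval.query": "Represent the query for retrieval: ",
}

def add_task_instruction(
    sentences: Union[str, List[str]], task: str
) -> Union[str, List[str]]:
    """Prepend the task instruction to a single string or to every string in a list."""
    instruction = TASK_INSTRUCTIONS[task]
    if isinstance(sentences, str):
        return instruction + sentences
    # List input: mirror the fixed else-branch of encode().
    return [instruction + sentence for sentence in sentences]

print(add_task_instruction("what is LoRA?", "retrieval.query"))
print(add_task_instruction(["what is LoRA?", "what is flash attention?"], "retrieval.query"))
```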