PEFT Finetuning Code Please
Hello
Thank you for making this work.
Could you please share your finetuning code using PEFT? I am running into an error and suspect there may be issues in my finetuning code. Here is the code I am using:
import torch
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    './mpt-7b-peft-compatible',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# freeze the model - train adapters later
for param in model.parameters():
    param.requires_grad = False

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
print_trainable_parameters(model)

output_dir = 'mpt-finetuned'
trainer = Trainer(
    model=model,
    train_dataset=train_ds_tokenized,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir=output_dir,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
I get this error on trainer.train(): AttributeError: 'MPTForCausalLM' object has no attribute 'model_parallel'
Hi, try this notebook: https://colab.research.google.com/drive/1iBeY5UTLHE3aL6yNLiCIJHOBDqWBYbi5?usp=sharing
Works fine for me. If you still encounter issues, try updating your packages, or let me know the specific error.
Even when using this notebook? https://colab.research.google.com/drive/1iBeY5UTLHE3aL6yNLiCIJHOBDqWBYbi5?usp=sharing
Worked fine for me on Google Colab.
Getting exactly the same error: 'MPTForCausalLM' object has no attribute 'model_parallel'
I am using a custom V100 setup.
@envyt48
@cekal
I managed to solve the problem. You need to move the following code (already present in the shared notebook at lines 243-246) to before the LoraConfig initialization (line 192). Basically, the PEFT wrapper's model_parallel attribute was being set, but not the underlying model's; we need to set it on the underlying model as well:
if not ddp and torch.cuda.device_count() > 1:
    # keeps Trainer from trying its own DataParallelism when more than 1 GPU is available
    model.is_parallelizable = True
    model.model_parallel = True
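In other words, the order should look roughly like this (a minimal sketch; ddp and model are as defined in the notebook):

import torch
from peft import LoraConfig, get_peft_model

# set the parallelism attributes on the *base* model, before PEFT wraps it,
# so they exist on the underlying MPTForCausalLM that Trainer inspects
if not ddp and torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)  # wrapping happens after the attributes are set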
This is a copy of the edited notebook and this should work now: https://colab.research.google.com/drive/1A9bnjSfRQg6GciIkSxbfX7vS8PuP83FN
I ran the above notebook on a single Colab GPU (I had to modify the code that @kdua mentioned by explicitly setting the two parallel props to False; see the sketch after the error text below). It ran for about 5 hours and then errored when trying to save, but more concerning, the loss did not decrease. See screenshot:
The full error text is:
There were missing keys in the checkpoint model loaded: ['base_model.model.transformer.wte.weight', 'base_model.model.transformer.blocks.0.norm_1.weight', 'base_model.model.transformer.blocks.0.attn.Wqkv.weight', 'base_model.model.transformer.blocks.0.attn.out_proj.weight', 'base_model.model.transformer.blocks.0.norm_2.weight', 'base_model.model.transformer.blocks.0.ffn.up_proj.weight', 'base_model.model.transformer.blocks.0.ffn.down_proj.weight', (the same six keys repeated for blocks 1 through 31), 'base_model.model.transformer.norm_f.weight'].
ERROR: Could not consume arg: -f
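For reference, a minimal sketch of the single-GPU modification mentioned above (it replaces @kdua's guarded block; with one GPU, torch.cuda.device_count() > 1 is False, so the attributes are set unconditionally):

# single GPU: the multi-GPU guard never fires, so set the attributes explicitly
model.is_parallelizable = False
model.model_parallel = False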
The missing keys error should be okay. Can you check if the adapters were saved? If so, what's their size? They should be named adapter_model.bin and adapter_config.json.
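For reference, a minimal sketch (not verbatim from the notebook) of how the adapters would be saved and reloaded, reusing the paths from the earlier snippets:

# saving a PeftModel writes only the adapter files
# (adapter_config.json plus the adapter weights), not the full base model
model.save_pretrained(output_dir)

# reloading: load the base model, then reattach the adapters
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained('./mpt-7b-peft-compatible', trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, output_dir)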
Let me know
The loss did not decrease for me either, yet the model seemed to learn the patterns from my training dataset well.