PEFT Finetuning Code Please

#3
by kdua - opened

Hello
Thank you for making this work.
Could you please share the fine-tuning code using PEFT? I am running into an error and suspect there may be issues in my own fine-tuning code. This is the code I am using:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    './mpt-7b-peft-compatible',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

output_dir = 'mpt-finetuned'

# tokenizer and train_ds_tokenized are prepared earlier
trainer = Trainer(
    model=model,
    train_dataset=train_ds_tokenized,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir=output_dir,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()

I get this error on trainer.train(): AttributeError: 'MPTForCausalLM' object has no attribute 'model_parallel'

Hi, try this notebook: https://colab.research.google.com/drive/1iBeY5UTLHE3aL6yNLiCIJHOBDqWBYbi5?usp=sharing

It works fine for me. If you still encounter issues, try updating some packages, or let me know the specific error.

I got this error: AttributeError: 'MPTForCausalLM' object has no attribute 'model_parallel'.
@cekal do you have any idea?

Even when using this notebook? https://colab.research.google.com/drive/1iBeY5UTLHE3aL6yNLiCIJHOBDqWBYbi5?usp=sharing
It worked fine for me on Google Colab.

Yes, I ran this notebook directly on Google Colab Pro.
[screenshot attached]
Could it be a problem that the runtime is an A100 GPU?

Getting exactly the same error: 'MPTForCausalLM' object has no attribute 'model_parallel'
I am using a custom V100 setup.

@envyt48 @cekal
I managed to solve the problem. You need to move the following code (already present in the above shared notebook at lines 243-246) to before the LoraConfig initialization (line 192). Basically, the PEFT model's model_parallel attribute was being set, but not the underlying model's; we need to set it on the underlying model.

if not ddp and torch.cuda.device_count() > 1:
    # keeps Trainer from trying its own DataParallelism when more than 1 GPU is available
    model.is_parallelizable = True
    model.model_parallel = True
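
For clarity, here is a minimal sketch of the corrected ordering (the ddp variable mirrors the notebook's distributed-setup check; treat the snippet as illustrative, not verbatim from the notebook):

import torch
from peft import LoraConfig, get_peft_model

# assumption: ddp is True only when running under torch.distributed
# with more than one process; False for a plain single/multi-GPU run
ddp = False

# set the flags on the *base* model first, so the underlying
# MPTForCausalLM carries model_parallel before PEFT wraps it
# (model is the MPTForCausalLM loaded earlier in the thread)
if not ddp and torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

# only now wrap with PEFT; the wrapper sits on a correctly flagged base model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)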

This is a copy of the edited notebook and this should work now: https://colab.research.google.com/drive/1A9bnjSfRQg6GciIkSxbfX7vS8PuP83FN

I ran the above notebook (I had to modify the code that @kdua mentioned by explicitly setting the two parallel props to False, since I was running on a single Colab GPU). It ran for about 5 hours and came up with an error while trying to save, but more concerning, the loss did not decrease. See screenshot:

[screenshot: training loss log]

The full error text is:
There were missing keys in the checkpoint model loaded: ['base_model.model.transformer.wte.weight', 'base_model.model.transformer.blocks.0.norm_1.weight', 'base_model.model.transformer.blocks.0.attn.Wqkv.weight', 'base_model.model.transformer.blocks.0.attn.out_proj.weight', 'base_model.model.transformer.blocks.0.norm_2.weight', 'base_model.model.transformer.blocks.0.ffn.up_proj.weight', 'base_model.model.transformer.blocks.0.ffn.down_proj.weight', ... (the same norm_1, attn.Wqkv, attn.out_proj, norm_2, ffn.up_proj, and ffn.down_proj keys repeat for blocks 1 through 31) ..., 'base_model.model.transformer.norm_f.weight'].
ERROR: Could not consume arg: -f

The missing-keys error should be okay. Can you check if the adapters were saved? If so, what's their size? They should be named adapter_model.bin and adapter_config.json.
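
If it helps, here's a quick way to check (a minimal sketch; it assumes the adapters were written to the output_dir = 'mpt-finetuned' used in the code above, e.g. via model.save_pretrained(output_dir)):

import os

output_dir = 'mpt-finetuned'
for name in ('adapter_model.bin', 'adapter_config.json'):
    path = os.path.join(output_dir, name)
    if os.path.exists(path):
        # for an r=16 LoRA the .bin is typically tens of MB, not gigabytes
        print(f"{name}: {os.path.getsize(path) / 1e6:.1f} MB")
    else:
        print(f"{name}: missing")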

Let me know

cekal changed discussion status to closed
cekal changed discussion status to open

The loss did not decrease for me either, yet it seemed to learn the patterns from my training dataset well.

kdua changed discussion status to closed
