Support gradient checkpointing #41
opened by muelletm
(This is the trace you get when trying to use mpt-7b with qlora. More details here: https://github.com/artidoro/qlora/issues/10)
```
/opt/conda/lib/python3.10/site-packages/peft/utils/other.py:76: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
Traceback (most recent call last):
  File "/code/qlora/qlora.py", line 758, in <module>
    train()
  File "/code/qlora/qlora.py", line 590, in train
    model = get_accelerate_model(args, checkpoint_dir)
  File "/code/qlora/qlora.py", line 295, in get_accelerate_model
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 80, in prepare_model_for_int8_training
    return prepare_model_for_kbit_training(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 69, in prepare_model_for_kbit_training
    model.gradient_checkpointing_enable()
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1620, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
```
Small update: @cekal added a version of mpt-7b that fixes this (and other problems):
https://huggingface.co/cekal/mpt-7b-peft-compatible
Maybe their code could be merged into this repo?
This looks like the change needed to enable gradient checkpointing?
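For anyone comparing against their own fork: the usual opt-in pattern for a custom model looks roughly like the sketch below. All names here (`Block`, `MPTSketch`, the dimensions) are illustrative stand-ins, not the actual code from cekal's repo:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Stand-in for a single MPT transformer block."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))


class MPTSketch(nn.Module):
    """Minimal sketch of a model that opts in to gradient checkpointing."""

    supports_gradient_checkpointing = True  # makes the check above pass

    def __init__(self, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_layers))
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # transformers invokes this via model.apply(...) on every submodule.
        if isinstance(module, MPTSketch):
            module.gradient_checkpointing = value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Recompute this block's activations during backward
                # instead of keeping them in memory.
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x
```

The two key pieces are the `supports_gradient_checkpointing` class attribute and the `checkpoint(...)` call, which trades extra compute for lower activation memory; whether the linked patch does exactly this, I haven't verified.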
Hmm, OK, but was that propagated to other MPT models like mpt-7b-chat and mpt-7b-8k-chat?
Thanks, noted. Feel free to close the issue.