norm function changes dtypes
#13
opened by sasawq21
No description provided.
Bug when fine-tuning internvl2-8b:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
    --model_type internvl2-8b \
    --model_id_or_path /root/.cache/modelscope/hub/OpenGVLab/InternVL2-8B/ \
    --dataset ./output/jsonl/train_dataset.jsonl \
    --max_length 4096 \
    --use_flash_attn true \
    --gradient_checkpointing true \
    --learning_rate 1e-6 \
    --num_train_epochs=3 \
    --gradient_accumulation_steps 64 \
    --preprocess_num_proc 48 \
    --quantization_bit 8 \
    --dtype bf16
```
The error is:
```
...
Train: 0%| | 0/7014 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
File "/root/projects/ms-swift/swift/cli/sft.py", line 5, in <module>
sft_main()
File "/root/projects/ms-swift/swift/utils/run_utils.py", line 32, in x_main
result = llm_x(args, **kwargs)
File "/root/projects/ms-swift/swift/llm/sft.py", line 405, in llm_sft
trainer.train(training_args.resume_from_checkpoint)
File "/root/projects/ms-swift/swift/trainers/mixin.py", line 538, in train
res = super().train(resume_from_checkpoint, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1948, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2289, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3328, in training_step
loss = self.compute_loss(model, inputs)
File "/root/projects/ms-swift/swift/trainers/trainers.py", line 179, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1577, in forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 188, in forward
return self.model.forward(*args, **kwargs)
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4093, in wrapper
return forward_func(
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 103, in forward
vit_embeds = self.extract_feature(pixel_values)
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4316, in _new_extract_feature
return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 181, in extract_feature
vit_embeds = self.vision_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 419, in forward
encoder_outputs = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 350, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/root/projects/ms-swift/swift/llm/utils/model.py", line 6224, in <lambda>
lambda *args, use_reentrant=_use_reentrant, **kwargs: _old_checkpoint(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 488, in checkpoint
ret = function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 296, in forward
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 252, in forward
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 244, in _flash_attn
context, _ = self.inner_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 66, in forward
assert qkv.dtype in [torch.float16, torch.bfloat16]
AssertionError
```
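For context: the stack above passes through `accelerate`'s `convert_to_fp32` wrapper and `torch/amp/autocast_mode.py`, so the forward runs under autocast. `layer_norm` is one of the ops autocast executes in float32, so `nn.LayerNorm` promotes a bf16 input to a float32 output, which the flash-attn path then rejects. A minimal sketch of that promotion, with hypothetical shapes (assumes a CUDA device, as in the run above):

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(8).cuda()
x = torch.randn(2, 4, 8, device="cuda", dtype=torch.bfloat16)

# layer_norm is on autocast's float32 op list, so the output dtype is
# promoted even though the input is bfloat16.
with torch.autocast("cuda", dtype=torch.bfloat16):
    y = norm(x)

print(y.dtype)  # torch.float32
assert y.dtype in [torch.float16, torch.bfloat16]  # AssertionError, as above
```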
hidden_states.dtype is changed to torch.float32 by nn.LayerNorm; I modified the two lines to cast it back to the same dtype as the input.
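For reference, a sketch of that cast against the encoder-layer forward quoted in the traceback (the norm1 line appears verbatim above; the symmetric norm2 line is assumed here, so see the actual diff in this PR):

```python
# modeling_intern_vit.py, encoder-layer forward (sketch, not the exact diff).
# norm1/norm2 return float32 under autocast; casting back to the residual
# stream's dtype lets the flash-attn path see fp16/bf16 again.
hidden_states = hidden_states + self.drop_path1(
    self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
hidden_states = hidden_states + self.drop_path2(
    self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
```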
sasawq21 changed pull request status to open
czczup changed pull request status to merged
Thank you for your feedback!