norm function changes dtypes

#13
No description provided.

Bug when fine-tuning InternVL2-8B

CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft --model_type internvl2-8b --model_id_or_path /root/.cache/modelscope/hub/OpenGVLab/InternVL2-8B/     --dataset ./output/jsonl/train_dataset.jsonl   --max_length 4096 --use_flash_attn true --gradient_checkpointing true --learning_rate 1e-6 --num_train_epochs=3   --gradient_accumulation_steps 64 --preprocess_num_proc 48 --quantization_bit 8     --dtype bf16

The error is:


...
Train:   0%|                                                                                                | 0/7014 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Traceback (most recent call last):
  File "/root/projects/ms-swift/swift/cli/sft.py", line 5, in <module>
    sft_main()
  File "/root/projects/ms-swift/swift/utils/run_utils.py", line 32, in x_main
    result = llm_x(args, **kwargs)
  File "/root/projects/ms-swift/swift/llm/sft.py", line 405, in llm_sft
    trainer.train(training_args.resume_from_checkpoint)
  File "/root/projects/ms-swift/swift/trainers/mixin.py", line 538, in train
    res = super().train(resume_from_checkpoint, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1948, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2289, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3328, in training_step
    loss = self.compute_loss(model, inputs)
  File "/root/projects/ms-swift/swift/trainers/trainers.py", line 179, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1577, in forward
    return self.base_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 188, in forward
    return self.model.forward(*args, **kwargs)
  File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4093, in wrapper
    return forward_func(
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 103, in forward
    vit_embeds = self.extract_feature(pixel_values)
  File "/root/projects/ms-swift/swift/llm/utils/model.py", line 4316, in _new_extract_feature
    return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 181, in extract_feature
    vit_embeds = self.vision_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 419, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 350, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/root/projects/ms-swift/swift/llm/utils/model.py", line 6224, in <lambda>
    lambda *args, use_reentrant=_use_reentrant, **kwargs: _old_checkpoint(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 488, in checkpoint
    ret = function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 296, in forward
    hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 252, in forward
    x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 244, in _flash_attn
    context, _ = self.inner_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_intern_vit.py", line 66, in forward
    assert qkv.dtype in [torch.float16, torch.bfloat16]
AssertionError

`hidden_states.dtype` is changed to `torch.float32` by `nn.LayerNorm`; I modified the two lines to cast the output back to the same dtype as the input (`input.dtype`), so the flash-attention assertion (`qkv.dtype in [torch.float16, torch.bfloat16]`) no longer fails.

sasawq21 changed pull request status to open
czczup changed pull request status to merged
OpenGVLab org

Thank you for your feedback!

Sign up or log in to comment