Avoiding RuntimeError when loading in 4-bit / 8-bit

#28

Convert pixel_values to a HalfTensor, since the model weights are in half precision and the input dtype must match:

model_inputs["pixel_values"] = model_inputs["pixel_values"].half()

This avoids the following error: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same
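For illustration, here is a minimal sketch of a more general variant of the same cast. It is not part of this PR: the helper name match_input_dtype, the tensor shapes, and the input_ids key are assumptions made for the example. The idea is to cast only floating-point tensors to the target dtype, so integer tensors such as token ids are left untouched.

```python
import torch


def match_input_dtype(model_inputs, target_dtype=torch.float16):
    """Cast floating-point tensors to target_dtype; leave everything else as-is.

    Hypothetical helper for illustration only, not an API of any library.
    """
    return {
        key: value.to(target_dtype)
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for key, value in model_inputs.items()
    }


# Placeholder inputs: a float32 image batch and integer token ids.
model_inputs = {
    "pixel_values": torch.rand(1, 3, 224, 224),
    "input_ids": torch.tensor([[1, 2, 3]]),
}

model_inputs = match_input_dtype(model_inputs)
print(model_inputs["pixel_values"].dtype)  # torch.float16
print(model_inputs["input_ids"].dtype)     # torch.int64
```

If the target dtype is not known in advance, one option is to read it from the model's weights (for example, the dtype of the layer that consumes pixel_values) rather than hard-coding torch.float16.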
