TypeError in PyTorch Model: Unexpected Keyword Argument 'num_logits_to_keep' in Custom Model Generation
TypeError Traceback (most recent call last)
in <cell line: 26>()
25
26 with torch.inference_mode():
---> 27 generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
28 generation = generation[0][input_len:]
29 decoded = processor.decode(generation, skip_special_tokens=True)
... (5 frames omitted) ...
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
172
TypeError: PaliGemmaForConditionalGeneration.forward() got an unexpected keyword argument 'num_logits_to_keep'
Context:
I'm encountering this error while generating text with a custom model in PyTorch. The code runs inside a `torch.inference_mode()` context and calls `model.generate(**model_inputs, max_new_tokens=100, do_sample=False)`. The error says that the `forward()` method of `PaliGemmaForConditionalGeneration` received an unexpected keyword argument, `num_logits_to_keep`, even though that argument does not appear explicitly in my `generate()` call.
Questions:
1. Has anyone encountered this issue before with custom PyTorch models or when using the `generate()` method?
2. Is there something wrong with the way I'm passing arguments to `model.generate()`? How can I resolve this `TypeError`?
Any insights or suggestions would be greatly appreciated!
Hi @Animeshsoulai,
Make sure that the `forward()` method signature either includes `**kwargs` or explicitly accepts a `num_logits_to_keep` parameter.
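One way to check that condition, and to get `**kwargs`-style tolerance without editing library code, is sketched below using only the standard-library `inspect` module. The helpers `accepts_kwarg` and `drop_unsupported_kwargs` and the two toy forward functions are hypothetical illustrations, not part of PyTorch or `transformers`.

```python
import functools
import inspect


def accepts_kwarg(fn, name):
    """Return True if `fn` has a parameter called `name` or a catch-all **kwargs."""
    params = inspect.signature(fn).parameters.values()
    return any(
        p.name == name or p.kind is inspect.Parameter.VAR_KEYWORD for p in params
    )


def drop_unsupported_kwargs(fn):
    """Hypothetical wrapper: silently drop keyword arguments `fn` cannot accept."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn  # already takes **kwargs, nothing to filter

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Keep only the keyword arguments that appear in fn's signature.
        kept = {k: v for k, v in kwargs.items() if k in params}
        return fn(*args, **kept)

    return wrapper


# Hypothetical stand-ins for a model's forward(); not the real PaliGemma signature.
def strict_forward(input_ids, attention_mask=None):
    return len(input_ids)


def tolerant_forward(input_ids, attention_mask=None, **kwargs):
    return len(input_ids)


print(accepts_kwarg(strict_forward, "num_logits_to_keep"))    # False
print(accepts_kwarg(tolerant_forward, "num_logits_to_keep"))  # True

safe_forward = drop_unsupported_kwargs(strict_forward)
print(safe_forward([1, 2, 3], num_logits_to_keep=1))  # 3
```

With a loaded model, the same `accepts_kwarg` check could be pointed at `model.forward`. Dropping `num_logits_to_keep` should only change how many positions' logits get computed, but treat that as an assumption to verify for your setup.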
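The `generate()` method calls `forward()` at each decoding step and passes the prepared arguments to it, so every keyword argument that reaches `forward()` must be one it accepts. Ensure that you're not passing unsupported arguments to `generate()`. Note that `num_logits_to_keep` does not appear in your call, so it is coming either from `model_inputs` or from `generate()` itself (recent `transformers` releases use this argument internally during generation); the `PaliGemmaForConditionalGeneration.forward()` in your installed version evidently does not accept it.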
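The toy loop below illustrates that call path: a stand-in for `generate()` forwards extra keyword arguments to `forward()` on every step, so a forward without `**kwargs` fails with the same kind of `TypeError`. `toy_generate` and both forward functions are hypothetical; this is not how `transformers` implements `generate()`.

```python
def toy_generate(forward, input_ids, max_new_tokens=3, **extra_forward_kwargs):
    """Greedy stand-in for generate(): one forward() call per new token.

    Hypothetical: it only mimics how a decoding loop passes keyword
    arguments through to forward(); it is not the transformers implementation.
    """
    tokens = list(input_ids)
    for _ in range(max_new_tokens):
        # Every extra keyword argument reaches forward() unchanged.
        next_token = forward(input_ids=tokens, **extra_forward_kwargs)
        tokens.append(next_token)
    return tokens


def forward_without_kwargs(input_ids):
    # Hypothetical "model": predicts the last token id plus one.
    return input_ids[-1] + 1


def forward_with_kwargs(input_ids, **kwargs):
    # Same toy model, but tolerates unknown keyword arguments.
    return input_ids[-1] + 1


print(toy_generate(forward_with_kwargs, [5, 6], num_logits_to_keep=1))
# [5, 6, 7, 8, 9]

try:
    toy_generate(forward_without_kwargs, [5, 6], num_logits_to_keep=1)
except TypeError as err:
    print(err)
    # forward_without_kwargs() got an unexpected keyword argument 'num_logits_to_keep'
```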
Before calling `generate()`, print `model_inputs.keys()` to see which arguments will be passed along. This shows whether `num_logits_to_keep` is among them.
Finally, update the `generate()` call accordingly: either remove `num_logits_to_keep` from `model_inputs`, or modify the model so that its `forward()` handles this argument.
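Both checks can be sketched with a plain dictionary standing in for `model_inputs`. This assumes `model_inputs` is dict-like (processor outputs in `transformers` generally are); the keys and values below are made up for illustration.

```python
# Hypothetical stand-in for the processor output; the real object is
# dict-like, so the same key operations should apply.
model_inputs = {
    "input_ids": [[2, 100, 200]],
    "attention_mask": [[1, 1, 1]],
    "num_logits_to_keep": 1,
}

# Step 1: see which arguments generate() would receive.
print(sorted(model_inputs.keys()))

# Step 2: remove the offending key if present; pop() with a default
# never raises KeyError when the key is missing.
removed = model_inputs.pop("num_logits_to_keep", None)
print(removed)                      # 1
print(sorted(model_inputs.keys()))  # ['attention_mask', 'input_ids']
```

If `num_logits_to_keep` is not among the keys, it is being added inside `generate()` rather than coming from your inputs, and the forward-side fix is the one that applies.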
If the issue persists, please share your code so we can assist you more effectively.
Thank you.