TypeError in PyTorch Model: Unexpected Keyword Argument 'num_logits_to_keep' in Custom Model Generation

#14
by Animeshsoulai - opened

TypeError                                 Traceback (most recent call last)
in <cell line: 26>()
     25
     26 with torch.inference_mode():
---> 27     generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
     28 generation = generation[0][input_len:]
     29 decoded = processor.decode(generation, skip_special_tokens=True)

5 frames
/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    168             output = module._old_forward(*args, **kwargs)
    169     else:
--> 170         output = module._old_forward(*args, **kwargs)
    171     return module._hf_hook.post_forward(module, output)
    172

TypeError: PaliGemmaForConditionalGeneration.forward() got an unexpected keyword argument 'num_logits_to_keep'


Context:
I'm encountering this error while generating text with a PaliGemma model in PyTorch. The call runs under the `torch.inference_mode()` context, and I invoke `model.generate()` with only `max_new_tokens` and `do_sample` on top of `model_inputs`. The traceback shows that somewhere along the way an unexpected keyword argument, `num_logits_to_keep`, is being passed to the `forward()` method of `PaliGemmaForConditionalGeneration`, which rejects it.

Questions:
1. Has anyone encountered this issue before with custom PyTorch models or when using the `generate()` method?
2. Is there something wrong with the way I'm passing arguments to `model.generate()`? How can I resolve this `TypeError`?

Any insights or suggestions would be greatly appreciated!
Google org

Hi @Animeshsoulai ,

Make sure that the `forward()` method signature includes `**kwargs`, or is explicitly designed to accept the `num_logits_to_keep` argument if you need it.
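As a minimal, self-contained illustration (plain Python, hypothetical function names): a signature without `**kwargs` raises exactly this kind of `TypeError`, while one with `**kwargs` absorbs the extra keyword:

```python
def strict_forward(input_ids, attention_mask=None):
    # No **kwargs: any unrecognised keyword raises a TypeError,
    # just like PaliGemmaForConditionalGeneration.forward() in the traceback.
    return input_ids

def tolerant_forward(input_ids, attention_mask=None, **kwargs):
    # **kwargs silently absorbs arguments the function does not use.
    return input_ids

try:
    strict_forward([1, 2, 3], num_logits_to_keep=1)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'num_logits_to_keep'

print(tolerant_forward([1, 2, 3], num_logits_to_keep=1))  # [1, 2, 3]
```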

The `generate()` method calls `forward()` internally, passing along the arguments it has collected, so first check that you aren't handing `generate()` arguments the model doesn't support. `num_logits_to_keep` is likely an argument that a newer version of the transformers generation code passes down, while the `PaliGemmaForConditionalGeneration` class you have installed doesn't accept it; this kind of mismatch between the code providing `generate()` and the code defining `forward()` is worth checking, for example by making sure both come from the same transformers release.

Before calling `generate()`, print `model_inputs.keys()` to see which arguments you are actually passing. This will tell you whether `num_logits_to_keep` is coming from `model_inputs` itself.
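One way to sketch that check (with a dummy stand-in for the real `forward()`; the actual signature will differ): `inspect.signature` lists the parameters `forward()` accepts, so any key in `model_inputs` outside that set is a candidate culprit:

```python
import inspect

def forward(input_ids, attention_mask=None, pixel_values=None):
    """Dummy stand-in for the model's forward(); the real signature differs."""
    return input_ids

model_inputs = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "num_logits_to_keep": 1,  # the unsupported key from the traceback
}

print("model_inputs keys:", sorted(model_inputs.keys()))

# Compare the inputs against what forward() actually accepts.
accepted = set(inspect.signature(forward).parameters)
unsupported = set(model_inputs) - accepted
print("unsupported:", unsupported)  # {'num_logits_to_keep'}
```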

Finally, either remove `num_logits_to_keep` from `model_inputs` before calling `generate()`, or modify the model so its `forward()` properly handles the argument.
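For the first option, a dict `pop` with a default removes the key safely whether or not it is present (a sketch; `model_inputs` here is a plain dict standing in for the processor output):

```python
model_inputs = {"input_ids": [1, 2, 3], "num_logits_to_keep": 1}

# pop with a default is a no-op when the key is absent, so this is safe
# to run unconditionally before model.generate(**model_inputs, ...).
model_inputs.pop("num_logits_to_keep", None)

print(sorted(model_inputs))  # ['input_ids']
```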

If the issue still persists, please share the code files so we can assist you more effectively.

Thank you.
