joaogante HF staff commited on
Commit
2701290
·
verified ·
1 Parent(s): 3e8ad6e

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +3 -1
generate.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
 
3
- def generate(model, input_ids, generation_config, **kwargs):
 
 
4
  generation_config = generation_config or model.generation_config # default to the model generation config
5
  cur_length = input_ids.shape[1]
6
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
 
1
  import torch
2
 
3
+ def generate(model, generation_config, **kwargs):
4
+ input_ids = kwargs.get("input_ids") or kwargs.get("input")
5
+
6
  generation_config = generation_config or model.generation_config # default to the model generation config
7
  cur_length = input_ids.shape[1]
8
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens