joaogante HF staff commited on
Commit
a9cda06
·
verified ·
1 Parent(s): 2701290

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -1
generate.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
  def generate(model, generation_config, **kwargs):
4
- input_ids = kwargs.get("input_ids") or kwargs.get("input")
5
 
6
  generation_config = generation_config or model.generation_config # default to the model generation config
7
  cur_length = input_ids.shape[1]
 
1
  import torch
2
 
3
  def generate(model, generation_config, **kwargs):
4
+ input_ids = kwargs.get("input_ids") or kwargs.get("inputs")
5
 
6
  generation_config = generation_config or model.generation_config # default to the model generation config
7
  cur_length = input_ids.shape[1]