joaogante HF staff commited on
Commit
299d479
·
verified ·
1 Parent(s): a9cda06

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -3
generate.py CHANGED
@@ -1,8 +1,6 @@
1
  import torch
2
 
3
- def generate(model, generation_config, **kwargs):
4
- input_ids = kwargs.get("input_ids") or kwargs.get("inputs")
5
-
6
  generation_config = generation_config or model.generation_config # default to the model generation config
7
  cur_length = input_ids.shape[1]
8
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
 
1
  import torch
2
 
3
+ def generate(model, input_ids, generation_config, **kwargs):
 
 
4
  generation_config = generation_config or model.generation_config # default to the model generation config
5
  cur_length = input_ids.shape[1]
6
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens