Update generate.py
Browse files- generate.py +1 -3
generate.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import torch
|
2 |
|
3 |
-
def generate(model, generation_config, **kwargs):
|
4 |
-
input_ids = kwargs.get("input_ids") or kwargs.get("inputs")
|
5 |
-
|
6 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
7 |
cur_length = input_ids.shape[1]
|
8 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
|
|
1 |
import torch
|
2 |
|
3 |
+
def generate(model, input_ids, generation_config, **kwargs):
|
|
|
|
|
4 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
5 |
cur_length = input_ids.shape[1]
|
6 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|