Update generate.py
Browse files- generate.py +1 -1
generate.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch
|
2 |
|
3 |
def generate(model, generation_config, **kwargs):
|
4 |
-
input_ids = kwargs.get("input_ids") or kwargs.get("
|
5 |
|
6 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
7 |
cur_length = input_ids.shape[1]
|
|
|
1 |
import torch
|
2 |
|
3 |
def generate(model, generation_config, **kwargs):
|
4 |
+
input_ids = kwargs.get("input_ids") or kwargs.get("inputs")
|
5 |
|
6 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
7 |
cur_length = input_ids.shape[1]
|