# test_generate_from_hub / generate.py
# Author: joaogante (HF staff)
# Commit: "Update generate.py" (299d479, verified)
import torch
def generate(model, input_ids, generation_config, **kwargs):
    """Greedy-decode a continuation of ``input_ids`` with ``model``.

    At every step the model is run on the full sequence produced so far, the
    highest-scoring token at the last position is picked with ``argmax`` and
    appended to every row. Decoding stops only when the target length is
    reached; there is no end-of-sequence handling.

    Args:
        model: Callable such that ``model(input_ids).logits`` is indexable as
            ``[:, -1, :]`` (i.e. a 3-D tensor). Its ``generation_config``
            attribute is used when ``generation_config`` is ``None``.
        input_ids: 2-D tensor of token ids; dimension 1 is the sequence
            length that gets extended.
        generation_config: Object exposing ``max_new_tokens`` and
            ``max_length``, or ``None`` to fall back to
            ``model.generation_config``.
        **kwargs: Accepted for call-site compatibility; ignored here.

    Returns:
        ``input_ids`` with the generated tokens concatenated along the last
        dimension. The input is returned as-is when it already reaches the
        target length.

    Raises:
        ValueError: If neither ``max_new_tokens`` nor ``max_length`` is set.
    """
    if generation_config is None:
        # Default to the model generation config.
        generation_config = model.generation_config

    cur_length = input_ids.shape[1]

    # `max_new_tokens` takes precedence over `max_length` when both are set.
    # Checking `max_length` first would silently ignore an explicit
    # `max_new_tokens` whenever the config also carries a `max_length`.
    max_new_tokens = generation_config.max_new_tokens
    if max_new_tokens is not None:
        max_length = cur_length + max_new_tokens
    elif generation_config.max_length is not None:
        max_length = generation_config.max_length
    else:
        raise ValueError(
            "generation_config must define `max_new_tokens` or `max_length`."
        )

    while cur_length < max_length:
        logits = model(input_ids).logits
        # Only the scores at the last position decide the next token.
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        # Add a trailing axis so the new tokens line up as one extra column.
        input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
        cur_length += 1

    return input_ids