|
import torch |
|
|
|
def generate(model, input_ids, generation_config=None, **kwargs):
    """Greedily extend ``input_ids`` one token at a time using ``model``.

    At each step the whole current sequence is fed through ``model`` and the
    highest-scoring token at the last position is appended. There is no
    key/value caching, so every step re-encodes the full prefix.

    Args:
        model: Callable that takes ``input_ids`` and returns an object with a
            ``.logits`` tensor indexable as ``logits[:, -1, :]``. Its
            ``generation_config`` attribute is used when ``generation_config``
            is ``None``.
        input_ids: 2-D tensor of token ids, shape ``(batch, seq_len)``; new
            tokens are appended along dimension 1.
        generation_config: Object providing ``max_new_tokens`` and/or
            ``max_length``. ``max_new_tokens`` (number of tokens to add) takes
            precedence when it is not ``None``; otherwise ``max_length`` is the
            total length limit, prompt included.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        ``input_ids`` with the generated tokens appended along dimension 1. If
        the prompt is already at or beyond the limit, it is returned unchanged.
        NOTE: there is no end-of-sequence handling; generation always runs to
        the length limit.

    Raises:
        ValueError: If neither ``max_new_tokens`` nor ``max_length`` is set.
    """
    generation_config = generation_config or model.generation_config

    cur_length = input_ids.shape[1]

    # Resolve the total length limit. ``max_new_tokens`` must win: HF-style
    # configs always carry a default ``max_length``, so preferring it would
    # silently ignore an explicit ``max_new_tokens``. Compare against ``None``
    # rather than truthiness so a limit of 0 is honoured.
    max_new_tokens = getattr(generation_config, "max_new_tokens", None)
    max_length = getattr(generation_config, "max_length", None)
    if max_new_tokens is not None:
        max_length = cur_length + max_new_tokens
    elif max_length is None:
        raise ValueError(
            "generation_config must set either `max_new_tokens` or `max_length`."
        )

    # Decoding needs no gradients. Disabling autograd keeps each step's
    # forward pass from building (and retaining) a computation graph.
    with torch.no_grad():
        while cur_length < max_length:
            logits = model(input_ids).logits
            # Scores for the token following the last position of each row.
            next_token_logits = logits[:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1)
            # next_tokens is 1-D; add a length-1 sequence axis before appending.
            input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
            cur_length += 1

    return input_ids
|
|