import torch


def generate(model, input_ids, generation_config, **kwargs):
    """Greedily append tokens to ``input_ids`` until a length limit is reached.

    At every step the model is run on the full sequence generated so far, the
    scores at the last position are reduced with ``argmax`` over the final
    axis, and the winning index is appended as a new column.

    Args:
        model: Callable invoked as ``model(input_ids)``; its result must
            expose a ``.logits`` tensor indexable as ``[:, -1, :]``. Its
            ``generation_config`` attribute is used when no config is passed.
        input_ids: 2-D integer tensor; axis 1 is treated as the sequence
            length and is the axis that grows.
        generation_config: Object with ``max_length`` and ``max_new_tokens``
            attributes, or ``None`` to use ``model.generation_config``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        ``input_ids`` with the generated tokens concatenated along the last
        axis. The input is returned unchanged when it is already at or past
        the length limit.

    Raises:
        TypeError: If ``max_length`` is unset (falsy) and ``max_new_tokens``
            is ``None``, because the length limit cannot be computed.
    """
    if generation_config is None:
        # default to the model generation config
        generation_config = model.generation_config

    cur_length = input_ids.shape[1]
    # A falsy ``max_length`` (None or 0) means "not set"; only then is the
    # limit derived from the prompt length plus ``max_new_tokens``.
    max_length = (
        generation_config.max_length
        or cur_length + generation_config.max_new_tokens
    )

    # Only integer token ids leave this function, so no gradient can flow to
    # the result. Disabling autograd avoids recording a graph (and holding
    # activations) for every forward pass of the decoding loop.
    with torch.no_grad():
        while cur_length < max_length:
            logits = model(input_ids).logits
            next_token_logits = logits[:, -1, :]  # scores at the last position
            next_tokens = torch.argmax(next_token_logits, dim=-1)
            # ``next_tokens`` has one entry per row; add an axis so it can be
            # concatenated as a new trailing column.
            input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
            cur_length += 1

    return input_ids