import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import get_device
from model import DecoderTransformer


def predict(x, model, max_output_len=30):
    device = get_device(seed=37)
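    # assumption: the caller has not already handled device placement or eval mode,
    # so switch off dropout and move the model and the prompt tokens to the device
    model.eval()
    model.to(device)
    x = x.to(device)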
    input_len = x.size(1)
    # x is of shape (B, Tr); Tr = running token count, increased by 1 after every loop iteration below
    while x.size(1) < input_len + max_output_len:
        # forward the model to get the logits
        with torch.no_grad():
            # the model's forward is assumed to return a (logits, loss) tuple, so [0] selects the logits
            logits = model(x)[0]  # (B, Tr, vocab_size)
            # take the logits at the last position, as that's the next-token prediction
            logits = logits[:, -1, :]  # (B, vocab_size)
            # turn the logits into probabilities over the vocabulary
            probs = F.softmax(logits, dim=-1)
            # do top-k sampling with k=50 (the huggingface pipeline default)
            # topk_probs and topk_indices both have shape (B, 50)
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            # sample a token from the top-k probabilities
            # note: multinomial does not require the input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
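            # (multinomial treats each row as non-negative sampling weights and
            # normalizes them internally, so the top-k probs need no renormalization)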
            # gather the corresponding vocabulary indices
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            # append the sampled token to the sequence, increasing Tr by 1
            x = torch.cat((x, xcol), dim=1)  # (B, Tr), Tr = Tr + 1
            # stop if the end token is generated
            # if xcol == config.end_token:
            #     break
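            # assumption: for batched input (B > 1) a scalar comparison would not work;
            # something like (xcol == config.end_token).all() would be needed, with
            # config assumed to be defined/imported elsewhere in the project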
    return x  # (B, input_len + max_output_len)
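

# Example usage: a minimal sketch. The tiktoken GPT-2 tokenizer is an assumption,
# and the no-argument DecoderTransformer() constructor is hypothetical; adapt both
# to whatever model.py and your checkpoint actually provide.
if __name__ == "__main__":
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    model = DecoderTransformer()  # hypothetical: construct/load with model.py's real config
    prompt = "Hello, I'm a language model,"
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long).unsqueeze(0)  # (1, T)
    out = predict(tokens, model, max_output_len=30)
    print(enc.decode(out[0].tolist()))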