# Shakespeare-Coriolanus / inference.py
# gitesh-grover's picture
# Upload inference.py
# 623ca17 verified
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import get_device
from model import DecoderTransformer
def predict(x, model, max_output_len = 30):
    """Autoregressively sample up to ``max_output_len`` new tokens.

    Args:
        x: LongTensor of shape (B, T) holding the prompt token ids.
        model: decoder model; ``model(x)`` returns an indexable result whose
            element [0] is the logits tensor of shape (B, T, vocab_size).
            NOTE(review): assumes [0] is logits (e.g. a (logits, loss) pair) —
            confirm against DecoderTransformer.forward in model.py.
        max_output_len: number of new tokens to append after the prompt.

    Returns:
        Tensor of shape (B, T + max_output_len): the prompt followed by the
        sampled continuation, on the device returned by ``get_device``.
    """
    # get_device is also passed a seed, presumably seeding the RNG so the
    # sampling below is reproducible — keep the call for that side effect.
    device = get_device(seed=37)
    # Fix: device was previously fetched but never applied to the input.
    x = x.to(device)
    input_len = x.size(1)
    # x is of shape (B, Tr). Tr = running token count, grows by 1 per loop.
    while x.size(1) < input_len + max_output_len:
        # Forward the model to get the logits; no grad needed at inference.
        with torch.no_grad():
            logits = model(x)[0]  # (B, Tr, vocab_size)
        # Only the last position predicts the next token.
        logits = logits[:, -1, :]  # (B, vocab_size)
        # Probabilities over the vocabulary.
        probs = F.softmax(logits, dim=-1)
        # Top-k sampling with k=50 (HuggingFace pipeline default); clamp k so
        # vocabularies smaller than 50 don't crash torch.topk.
        k = min(50, probs.size(-1))
        topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # (B, k) each
        # Select a token from the top-k probabilities.
        # Note: multinomial does not demand the input to sum to 1.
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)
        # Map the sampled position back to a vocabulary index.
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
        # Append to the sequence, increasing Tr by 1.
        x = torch.cat((x, xcol), dim=1)  # (B, Tr+1)
        # TODO: stop early once an end-of-text token is produced
        # (requires the end-token id, e.g. from a config object).
    return x  # (B, input_len + max_output_len)