|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.modules.utils import _single
|
|
import utils
|
|
from multihead_attention import MultiheadAttention
|
|
import numpy as np
|
|
# Default compute device for this module: the GPU when CUDA is usable, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
import copy
|
|
|
|
|
|
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).

    Args:
        tensor: 2-D tensor of token ids, [bsz x seqlen].
        padding_idx: id marking padding; those entries are copied through unchanged.
        left_pad: True when rows are padded on the left, so that each row's
            numbering is shifted to start at its first non-padding token.

    Returns:
        A new tensor with the same shape and dtype as ``tensor``; the input
        is not modified.
    """
    seq_len = tensor.size(1)
    # One row of candidate positions, broadcast over the batch.
    positions = torch.arange(padding_idx + 1, padding_idx + 1 + seq_len,
                             dtype=tensor.dtype, device=tensor.device)
    positions = positions.unsqueeze(0).expand_as(tensor)
    mask = tensor.ne(padding_idx)
    if left_pad:
        # Shift each row by its amount of padding so numbering starts at the
        # first real token.
        positions = positions - seq_len + mask.long().sum(dim=1, keepdim=True)
    # Keep padding ids, substitute positions everywhere else.
    return torch.where(mask, positions, tensor)
|
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
        nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
        # normal_ also overwrote the padding row, which nn.Embedding initialises
        # to zeros and never updates; restore the zero vector.
        if self.padding_idx is not None:
            nn.init.constant_(self.weight[self.padding_idx], 0)

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen].

        With ``incremental_state`` set, only the newest position
        (padding_idx + seqlen) is looked up, giving a (1, 1, embedding_dim)
        result. Otherwise, positions for every token come from make_positions.
        """
        if incremental_state is not None:
            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
        else:
            positions = make_positions(input.data, self.padding_idx, self.left_pad)
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        # Empty buffer that follows the module through .to()/.half()/...;
        # forward() casts the plain-attribute `weights` table to match it.
        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a (num_embeddings, embedding_dim) float tensor: sines in the
        first half of the columns, cosines in the second, a zero column when
        embedding_dim is odd, and a zero row at padding_idx (if given).
        """
        half_dim = embedding_dim // 2
        # With a single frequency (embedding_dim 2 or 3) the rate is irrelevant,
        # since exp(0 * rate) == 1. Guard the denominator so it is never zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        # Already (num_embeddings, 2 * half_dim). A .view(n, -1) here would
        # fail when half_dim == 0, because the tensor has no elements.
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            # zero pad the odd trailing column
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # grow the table on demand
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.type_as(self._float_tensor)

        if incremental_state is not None:
            # single decoding step: only the newest position, (bsz, 1, dim)
            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

        positions = make_positions(input.data, self.padding_idx, self.left_pad)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)
|
|
|
|
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    Three sub-layers, each followed by dropout and a residual connection:
    causal self-attention, cross-attention over image and/or ingredient
    features, and a two-layer ReLU feed-forward. Layer norm is applied before
    each sub-layer when normalize_before is True, otherwise after it. An
    optional extra LayerNorm is applied to the output when last_ln is True.
    """

    def __init__(self, embed_dim, n_att, dropout=0.5, normalize_before=True, last_ln=False):
        super().__init__()

        self.embed_dim = embed_dim
        self.dropout = dropout
        self.relu_dropout = dropout
        self.normalize_before = normalize_before
        num_layer_norm = 3

        # self-attention over the decoded tokens
        self.self_attn = MultiheadAttention(
            self.embed_dim, n_att,
            dropout=dropout,
        )

        # cross-attention over the conditioning features
        self.cond_att = MultiheadAttention(
            self.embed_dim, n_att,
            dropout=dropout,
        )

        self.fc1 = Linear(self.embed_dim, self.embed_dim)
        self.fc2 = Linear(self.embed_dim, self.embed_dim)
        # [0]: self-attention, [1]: cross-attention, [2] (used as -1): feed-forward
        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(num_layer_norm)])
        self.use_last_ln = last_ln
        if self.use_last_ln:
            self.last_ln = LayerNorm(self.embed_dim)

    @staticmethod
    def _conditioning(ingr_features, ingr_mask, img_features):
        """Select cross-attention keys/values and their key-padding mask.

        - Only image features: they are used unmasked.
        - Only ingredient features: ingr_mask is used as-is.
        - Both: image and ingredient features are concatenated along dim 0.
          The mask is built from an all-zero (never masked) block for the
          image keys followed by ingr_mask along dim 1, or is None when
          ingr_mask is None.

        The image block is created with ingr_mask.new_zeros, so it shares
        ingr_mask's dtype and device. Assumes features are (len, bsz, dim) and
        ingr_mask is (bsz, len) -- TODO confirm against caller.
        """
        if ingr_features is None:
            return img_features, None
        if img_features is None:
            return ingr_features, ingr_mask
        kv = torch.cat((img_features, ingr_features), 0)
        if ingr_mask is None:
            return kv, None
        img_mask = ingr_mask.new_zeros(img_features.shape[1], img_features.shape[0])
        return kv, torch.cat((img_mask, ingr_mask), 1)

    def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
        """Run the layer on ``x``.

        ``x`` is presumably (tgt_len, bsz, embed_dim) -- TODO confirm. At least
        one of ingr_features / img_features must be given.
        """
        # causal self-attention
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            need_weights=False,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)

        # cross-attention over image and/or ingredient features
        residual = x
        x = self.maybe_layer_norm(1, x, before=True)
        kv, key_padding_mask = self._conditioning(ingr_features, ingr_mask, img_features)
        x, _ = self.cond_att(query=x,
                             key=kv,
                             value=kv,
                             key_padding_mask=key_padding_mask,
                             incremental_state=incremental_state,
                             static_kv=True,
                             )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)

        # position-wise feed-forward
        residual = x
        x = self.maybe_layer_norm(-1, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(-1, x, after=True)

        if self.use_last_ln:
            x = self.last_ln(x)

        return x

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply layer norm ``i`` iff this call's position (before/after) matches normalize_before."""
        assert before ^ after
        if after ^ self.normalize_before:
            return self.layer_norms[i](x)
        else:
            return x
|
|
|
|
class DecoderTransformer(nn.Module):
    """Transformer decoder.

    Embeds caption tokens (plus optional positional embeddings), runs them
    through a stack of TransformerDecoderLayer blocks that cross-attend to
    image and/or ingredient features, and projects to vocabulary logits.
    """

    def __init__(self, embed_size, vocab_size, dropout=0.5, seq_length=20, num_instrs=15,
                 attention_nheads=16, pos_embeddings=True, num_layers=8, learned=True, normalize_before=True,
                 normalize_inputs=False, last_ln=False, scale_embed_grad=False):
        super(DecoderTransformer, self).__init__()
        self.dropout = dropout
        # number of decoding steps used by sample() / sample_beam()
        self.seq_length = seq_length * num_instrs
        # index vocab_size-1 is the padding token; the output layer has no logit for it
        self.embed_tokens = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size-1,
                                         scale_grad_by_freq=scale_embed_grad)
        nn.init.normal_(self.embed_tokens.weight, mean=0, std=embed_size ** -0.5)
        # normal_ overwrote nn.Embedding's zero padding row; restore it
        nn.init.constant_(self.embed_tokens.weight[vocab_size - 1], 0)
        if pos_embeddings:
            self.embed_positions = PositionalEmbedding(1024, embed_size, 0, left_pad=False, learned=learned)
        else:
            self.embed_positions = None
        self.normalize_inputs = normalize_inputs
        if self.normalize_inputs:
            # [0]: ingredient features, [1]: image features, [2]: token embeddings
            self.layer_norms_in = nn.ModuleList([LayerNorm(embed_size) for _ in range(3)])

        self.embed_scale = math.sqrt(embed_size)
        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(embed_size, attention_nheads, dropout=dropout, normalize_before=normalize_before,
                                    last_ln=last_ln)
            for _ in range(num_layers)
        ])

        self.linear = Linear(embed_size, vocab_size-1)

    def forward(self, ingr_features, ingr_mask, captions, img_features, incremental_state=None):
        """Compute vocabulary logits for ``captions``.

        Args:
            ingr_features: ingredient features or None. They are permuted and
                transposed so that input dim 2 becomes dim 0 and the embedding
                dim becomes last (presumably (bsz, embed, len) -> (len, bsz,
                embed) -- TODO confirm).
            ingr_mask: mask for ingr_features or None. It is squeezed on dim 1
                and inverted to a byte key-padding mask (presumably the input
                uses 1 = valid).
            captions: [bsz x seqlen] token ids.
            img_features: image features or None, reshaped like ingr_features.
            incremental_state: decoding cache dict. When given, only the last
                caption token is embedded and run.

        Returns:
            (logits, predicted): logits of shape (bsz, steps, vocab_size - 1)
            and their argmax over the last dim.
        """
        if ingr_features is not None:
            ingr_features = ingr_features.permute(0, 2, 1)
            ingr_features = ingr_features.transpose(0, 1)
            if self.normalize_inputs:
                # The result must be assigned, otherwise the norm is a no-op.
                # NOTE(review): checkpoints trained while this result was
                # discarded never saw normalized features; re-validate them.
                ingr_features = self.layer_norms_in[0](ingr_features)

        if img_features is not None:
            img_features = img_features.permute(0, 2, 1)
            img_features = img_features.transpose(0, 1)
            if self.normalize_inputs:
                img_features = self.layer_norms_in[1](img_features)

        if ingr_mask is not None:
            ingr_mask = (1-ingr_mask.squeeze(1)).byte()

        if self.embed_positions is not None:
            positions = self.embed_positions(captions, incremental_state=incremental_state)

        if incremental_state is not None:
            # Only the newest token is run; earlier steps are presumably cached
            # by the attention layers in incremental_state.
            if self.embed_positions is not None:
                positions = positions[:, -1:]
            captions = captions[:, -1:]

        x = self.embed_scale * self.embed_tokens(captions)
        if self.embed_positions is not None:
            x += positions
        if self.normalize_inputs:
            x = self.layer_norms_in[2](x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # (bsz, steps, embed) -> (steps, bsz, embed) for the decoder layers
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(
                x,
                ingr_features,
                ingr_mask,
                incremental_state,
                img_features
            )
        x = x.transpose(0, 1)

        x = self.linear(x)
        _, predicted = x.max(dim=-1)

        return x, predicted

    @staticmethod
    def _batch_size(ingr_features, img_features):
        """Batch size taken from whichever conditioning input is present."""
        if ingr_features is not None:
            return ingr_features.size(0)
        return img_features.size(0)

    def _start_tokens(self, batch_size, first_token_value):
        """(batch_size,) long tensor of first_token_value on the model's device."""
        return torch.full((batch_size,), int(first_token_value), dtype=torch.long,
                          device=self.embed_tokens.weight.device)

    @staticmethod
    def _sample_from_topk(probs, k):
        """Sample one index per row of ``probs`` from that row's k most probable entries.

        ``probs`` is (rows, vocab); k is clipped to the vocabulary size. Each
        row uses its own draw.
        """
        k = min(k, probs.size(1))
        topk_probs, topk_indices = torch.topk(probs, k=k, dim=1)
        choice = torch.multinomial(topk_probs, 1)  # (rows, 1) position inside each row's top-k
        return topk_indices.gather(1, choice).squeeze(1)

    @staticmethod
    def _copy_incremental_state(state):
        """Copy a (possibly nested) incremental-state dict, cloning its tensors.

        Beam search gives each hypothesis its own copy, so that running one
        hypothesis cannot alter the cache seen by its siblings.
        """
        if isinstance(state, dict):
            return {key: DecoderTransformer._copy_incremental_state(value) for key, value in state.items()}
        if torch.is_tensor(state):
            return state.clone()
        return state

    @staticmethod
    def _no_replacement_mask(outputs, history):
        """Additive mask shaped like ``outputs``: -inf at every (row, token) emitted in ``history``.

        ``history`` is a list of (rows,) token tensors. Token id 0 is never masked.
        """
        mask = torch.zeros_like(outputs)
        for tokens in history:
            rows = torch.nonzero(tokens != 0).view(-1)
            mask[rows, tokens[rows]] = float('-inf')
        return mask

    def sample(self, ingr_features, ingr_mask, greedy=True, temperature=1.0, beam=-1,
               img_features=None, first_token_value=0,
               replacement=True, last_token_value=0):
        """Autoregressively decode self.seq_length tokens.

        Args:
            greedy: take the argmax at each step; otherwise sample from the
                temperature-scaled top-10 distribution of each row.
            beam: if != -1 and the batch size is 1, delegate to sample_beam.
            first_token_value: id fed as the first input token.
            replacement: when False, a token already emitted by a row gets
                -inf logit afterwards (token id 0 is exempt).
            last_token_value: only used by sample_beam.

        Returns:
            (sampled_ids, logits): (bsz, seq_length) token ids, and the
            per-step (possibly masked) logits stacked along dim 1.
        """
        incremental_state = {}

        fs = self._batch_size(ingr_features, img_features)

        if beam != -1:
            if fs == 1:
                return self.sample_beam(ingr_features, ingr_mask, beam, img_features, first_token_value,
                                        replacement, last_token_value)
            else:
                print("Beam Search can only be used with batch size of 1. Running greedy or temperature sampling...")

        first_word = self._start_tokens(fs, first_token_value)
        sampled_ids = [first_word]
        logits = []

        for i in range(self.seq_length):
            outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sampled_ids, 1),
                                      img_features, incremental_state)
            outputs = outputs.squeeze(1)
            if not replacement:
                if i == 0:
                    predicted_mask = torch.zeros_like(outputs)
                else:
                    # forbid each row's previous token from now on (id 0 is exempt)
                    last = sampled_ids[i]
                    rows = torch.nonzero(last != 0).view(-1)
                    predicted_mask[rows, last[rows]] = float('-inf')
                outputs += predicted_mask

            logits.append(outputs)
            if greedy:
                outputs_prob = torch.nn.functional.softmax(outputs, dim=-1)
                _, predicted = outputs_prob.max(1)
                predicted = predicted.detach()
            else:
                k = 10
                outputs_prob = torch.div(outputs, temperature)
                outputs_prob = torch.nn.functional.softmax(outputs_prob, dim=-1).data
                predicted = self._sample_from_topk(outputs_prob, k).detach()

            sampled_ids.append(predicted)

        sampled_ids = torch.stack(sampled_ids[1:], 1)
        logits = torch.stack(logits, 1)

        return sampled_ids, logits

    def sample_beam(self, ingr_features, ingr_mask, beam=3, img_features=None, first_token_value=0,
                    replacement=True, last_token_value=0):
        """Beam-search decoding (intended for batch size 1).

        A hypothesis finishes when it emits last_token_value.

        Returns:
            (sampled_ids, score): token ids of the best finished hypothesis (or
            the best unfinished one if none finished), without the first
            token, and its summed log-probability.
        """
        k = beam
        alpha = 0.0  # length-normalisation exponent; 0 disables it

        def normalized_score(hyp):
            return hyp[1] / (np.power(hyp[-1], alpha))

        fs = self._batch_size(ingr_features, img_features)
        first_word = self._start_tokens(fs, first_token_value)

        # hypothesis: [tokens, score, incremental_state, finished, length]
        sequences = [[[first_word], 0, {}, False, 1]]
        finished = []

        for i in range(self.seq_length):
            all_candidates = []
            for seq in sequences:
                # The parent's state is shared by all its children, so each
                # hypothesis runs on a private copy.
                incremental = self._copy_incremental_state(seq[2])
                outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(seq[0], 1),
                                          img_features, incremental)
                outputs = outputs.squeeze(1)
                if not replacement:
                    # per-hypothesis mask built from this hypothesis' own tokens
                    outputs = outputs + self._no_replacement_mask(outputs, seq[0][1:])

                outputs_prob = torch.nn.functional.log_softmax(outputs, dim=-1)
                probs, indices = torch.topk(outputs_prob, beam)

                for bid in range(beam):
                    tokens = seq[0] + [indices[:, bid]]
                    score = seq[1] + probs[:, bid].squeeze().item()
                    if indices[:, bid].item() == last_token_value:
                        finished.append([tokens, score, None, True, seq[-1] + 1])
                    else:
                        all_candidates.append([tokens, score, incremental, False, seq[-1] + 1])

            ordered_all = sorted(all_candidates + finished, key=normalized_score, reverse=True)[:k]
            # Stop expanding once the k best hypotheses are all finished
            # (index 3 is the finished flag).
            if all(hyp[3] for hyp in ordered_all):
                all_candidates = []

            sequences = sorted(all_candidates, key=normalized_score, reverse=True)[:k]
            finished = sorted(finished, key=normalized_score, reverse=True)[:k]
            if not sequences:
                break

        if len(finished) != 0:
            sampled_ids = torch.stack(finished[0][0][1:], 1)
            logits = finished[0][1]
        else:
            sampled_ids = torch.stack(sequences[0][0][1:], 1)
            logits = sequences[0][1]
        return sampled_ids, logits

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Drop the stored sinusoidal table and ensure its _float_tensor buffer key exists."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder.embed_positions.weights' in state_dict:
                del state_dict['decoder.embed_positions.weights']
            if 'decoder.embed_positions._float_tensor' not in state_dict:
                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
        return state_dict
|
|
|
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """nn.Embedding initialised from N(0, embedding_dim ** -0.5).

    The padding row (if padding_idx is not None) is reset to zeros afterwards,
    preserving nn.Embedding's zero padding vector that normal_ overwrote.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
|
|
|
|
|
|
def LayerNorm(embedding_dim):
    """Factory returning a default ``nn.LayerNorm`` over ``embedding_dim`` features."""
    return nn.LayerNorm(embedding_dim)
|
|
|
|
|
|
def Linear(in_features, out_features, bias=True):
    """nn.Linear with Xavier-uniform weights and (when present) a zero bias.

    bias=False is supported: the module then has no bias to initialise.
    """
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.)
    return m
|
|
|
|
|
|
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    """Factory for a positional-embedding module.

    learned=False -> SinusoidalPositionalEmbedding with init_size=num_embeddings.
    learned=True  -> LearnedPositionalEmbedding with num_embeddings rows; its
                     weight is re-drawn from N(0, embedding_dim ** -0.5) and the
                     padding_idx row is zeroed.
    """
    if not learned:
        return SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
    module = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
    std = embedding_dim ** -0.5
    nn.init.normal_(module.weight, mean=0, std=std)
    nn.init.constant_(module.weight[padding_idx], 0)
    return module
|
|
|