import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

from utils import fill_with_neg_inf, get_incremental_state, set_incremental_state


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self._mask = None

        # the query, key and value projections are stored in a single weight
        # matrix, stacked along dim 0 and sliced apart in _in_proj()
        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
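        # Illustrative example (arbitrary sizes, not required values): for a
        # batch of 2 source sequences of length 7 where the second sequence is
        # padded in its last 3 positions, key_padding_mask would be
        #   torch.tensor([[0, 0, 0, 0, 0, 0, 0],
        #                 [0, 0, 0, 0, 1, 1, 1]], dtype=torch.uint8)  # bsz x src_len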
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                # an empty tensor that can be concatenated with the cached
                # key/value below, yielding just the previous key/value
                k = v = q.new(0)
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if saved_state is not None:
            # prepend cached key/value from previous time steps
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # future-timestep masking is only needed when decoding a full sequence
        # at once (incremental decoding attends only to cached past steps)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float('-inf'),
            ).type_as(attn_weights)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        # average attention weights over heads
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=1) / self.num_heads

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        # project with a slice of the combined q/k/v weight (and bias)
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        # lazily build (and grow as needed) a cached upper-triangular mask of
        # -inf above the diagonal, used to block attention to future timesteps
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                # cached tensors are Time x Batch x Channel, so reorder along
                # the batch dimension (dim 1)
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )
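

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It exercises the self-attention
# path with future-timestep masking and a key padding mask; all sizes below
# are arbitrary examples, and the snippet assumes the `utils` helpers imported
# at the top of this file are available. Incremental decoding would
# additionally pass a dict-like `incremental_state` managed by the
# get_incremental_state/set_incremental_state helpers.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    embed_dim, num_heads, seq_len, bsz = 16, 4, 5, 2

    attn_layer = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
    attn_layer.eval()  # disable dropout so the example is deterministic

    x = torch.rand(seq_len, bsz, embed_dim)  # Time x Batch x Channel
    # mark the last two positions of the second sequence as padding
    key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.uint8)
    key_padding_mask[1, -2:] = 1

    attn, attn_weights = attn_layer(
        x, x, x,                     # self-attention: query = key = value
        mask_future_timesteps=True,  # causal masking
        key_padding_mask=key_padding_mask,
    )
    print(attn.size())          # torch.Size([5, 2, 16])
    print(attn_weights.size())  # torch.Size([2, 5, 5])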