from collections import defaultdict, OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location


def torch_persistent_save(*args, **kwargs):
    # Retry the save up to 3 times; log the traceback if the last attempt fails.
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    """Recursively convert all tensors in a (possibly nested) state dict to `ttype`."""
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            # Propagate ttype so nested tensors are converted consistently.
            cpu_dict[k] = convert_state_dict_type(v, ttype)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
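

# Illustrative sketch (not part of the original module): casting every tensor
# in a model's state dict to half precision, e.g. to shrink checkpoint files.
# `model` can be any nn.Module; the `_example_*` helper name is hypothetical.
def _example_half_precision_state(model):
    # Nested dicts/lists of tensors are handled recursively.
    return convert_state_dict_type(model.state_dict(), ttype=torch.HalfTensor)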


def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    """Save all training state in a checkpoint file."""
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': convert_state_dict_type(model.state_dict()),
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def load_model_state(filename, model):
    """Load model parameters from a checkpoint.

    Returns a tuple of (extra_state, optimizer_history, last_optimizer_state),
    or (None, [], None) if the checkpoint file does not exist.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
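

# Illustrative sketch (hypothetical helper): saving a checkpoint and restoring
# the model weights from it. The arguments mirror the fairseq trainer's types
# (args namespace, model, criterion, optimizer, lr_scheduler).
def _example_checkpoint_roundtrip(filename, args, model, criterion, optimizer, lr_scheduler):
    save_state(filename, args, model, criterion, optimizer, lr_scheduler, num_updates=1000)
    extra_state, optim_history, last_optim_state = load_model_state(filename, model)
    return extra_state, optim_history, last_optim_state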


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into a sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old checkpoints may not have separate source/target position limits
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': 0,
        }
    return state


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary -- {'arg_name': arg}
    -- to override model args that were used during model training.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)
    args = states[0]['args']
    if model_arg_overrides is not None:
        args = _override_model_args(args, model_arg_overrides)

    # build ensemble
    ensemble = []
    for state in states:
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args


def _override_model_args(args, model_arg_overrides):
    # override any args that were set during model training
    for arg_name, arg_val in model_arg_overrides.items():
        setattr(args, arg_name, arg_val)
    return args
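

# Illustrative sketch (hypothetical file names and override): loading two
# checkpoints as an ensemble for inference, overriding a training-time model
# arg. `task` is a fairseq task providing build_model(); 'dropout' is an
# assumed arg name used purely for illustration.
def _example_load_ensemble(task):
    return load_ensemble_for_inference(
        ['checkpoint1.pt', 'checkpoint2.pt'],
        task,
        model_arg_overrides={'dropout': 0.0},
    )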


def move_to_cuda(sample):
    """Recursively move all tensors in `sample` (nested dicts/lists) to CUDA."""
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {
                key: _move_to_cuda(value)
                for key, value in maybe_tensor.items()
            }
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_fairseq_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value
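

# Illustrative sketch (hypothetical module): caching inputs across incremental
# decoding steps. `incremental_state` is the per-sequence dict that fairseq
# threads through the decoder during generation.
class _ExampleCachingModule(torch.nn.Module):
    def forward(self, x, incremental_state=None):
        # When incremental_state is None (non-incremental mode), the helpers
        # are no-ops and the input passes straight through.
        prev = get_incremental_state(self, incremental_state, 'prev_x')
        if prev is not None:
            x = torch.cat([prev, x], dim=0)  # prepend cached timesteps
        set_incremental_state(self, incremental_state, 'prev_x', x)
        return x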


def load_align_dict(replace_unk):
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str):
        # Load an alignment dictionary for unknown word replacement if one was
        # passed as an argument.
        align_dict = {}
        with open(replace_unk, 'r') as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided, but we still want to perform
        # unknown word replacement by copying the aligned source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse an embedding text file into a dictionary of word and embedding tensors.

    The first line is assumed to be a header (e.g. vocabulary size and
    dimension) and is skipped. Each following line should contain a word and
    its embedding, separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    """Copy pretrained vectors from embed_dict into `embedding` for tokens in `vocab`."""
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding
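

# Illustrative sketch (hypothetical helper): initializing an nn.Embedding for
# `vocab` (assumed to be a fairseq Dictionary) from a pretrained embedding
# text file in the format documented on parse_embedding above.
def _example_load_pretrained_embeddings(embed_path, vocab):
    embed_dict = parse_embedding(embed_path)
    print_embed_overlap(embed_dict, vocab)
    embed_dim = next(iter(embed_dict.values())).numel()
    embedding = torch.nn.Embedding(len(vocab), embed_dim, padding_idx=vocab.pad())
    return load_embedding(embed_dict, vocab, embedding)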


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    from fairseq import tokenizer
    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the alignment dictionary
            # or just copy the original source word.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return ' '.join(hypo_tokens)
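

# Illustrative sketch (made-up sentences): replacing an '<unk>' token in a
# hypothesis with the aligned source word, mapped through a one-entry
# alignment dictionary. alignment[i] gives the source position aligned to
# hypothesis token i.
def _example_unk_replacement():
    align_dict = {'Haus': 'house'}
    # Returns 'the house is red'.
    return replace_unk('the <unk> is red', 'das Haus ist rot',
                       alignment=[0, 1, 2, 3], align_dict=align_dict, unk='<unk>')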


def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
    from fairseq import tokenizer
    hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
    if align_dict is not None:
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluation with unk replacement and/or
        # without BPE
        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    # Cache the position buffer on the function object so it is only
    # reallocated when a longer batch is seen.
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
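

# Illustrative sketch: position numbers start at padding_idx + 1, and padding
# positions keep the padding index itself.
def _example_positions():
    tokens = torch.LongTensor([[4, 5, 1], [4, 4, 4]])  # padding_idx = 1, right-padded
    # Returns tensor([[2, 3, 1], [2, 3, 4]]).
    return make_positions(tokens, padding_idx=1, left_pad=False)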


def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    if not hasattr(buffered_arange, 'buf'):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]


def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)
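

# Illustrative sketch: flipping a left-padded batch (padding index 1) to
# right-padding, as expected by encoders that assume right-padded inputs.
def _example_pad_flip():
    src_tokens = torch.LongTensor([[1, 1, 5, 6]])
    # Returns tensor([[5, 6, 1, 1]]).
    return convert_padding_direction(src_tokens, padding_idx=1, left_to_right=True)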


def item(tensor):
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor


def clip_grad_norm_(tensor, max_norm):
    """Clip (in place) the norm of a single flattened gradient tensor."""
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm
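

# Illustrative sketch: unlike torch.nn.utils.clip_grad_norm_, which takes an
# iterable of parameters, this variant operates in place on a single
# pre-flattened gradient tensor and returns the norm measured before clipping.
def _example_clip(flat_grads):
    grad_norm = clip_grad_norm_(flat_grads, max_norm=0.1)
    return grad_norm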


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
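

# Illustrative sketch: one common use of fill_with_neg_inf is building a
# causal (future-masking) attention mask that also works in fp16.
def _example_causal_mask(dim):
    # Zeros on and below the diagonal, -inf strictly above it.
    return torch.triu(fill_with_neg_inf(torch.zeros(dim, dim)), 1)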


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern.
    If the pattern contains groups, the result will be sorted by the first
    group in descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
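

# Illustrative sketch (hypothetical directory): picking the newest checkpoint;
# paths come back sorted by checkpoint number in descending order.
def _example_latest_checkpoint(save_dir):
    paths = checkpoint_paths(save_dir)  # e.g. ['.../checkpoint3.pt', '.../checkpoint1.pt']
    return paths[0] if paths else None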