'''Download the pretrained PaddlePaddle UIE weights and convert them to PyTorch format.
'''
import argparse
import collections
import json
import os
import pickle
import torch
import logging
import shutil
from tqdm import tqdm
import time
# logging.Logger('log') would create a logger with no handlers, so .info()
# calls would be silently dropped; configure the root logger instead
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
decompress (bool): decompress zip or tar file. Default is `True`
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
    import tarfile
    import zipfile
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = os.path.split(url)[-1]
fpath = fname
return os.path.join(root_dir, fpath)
def _get_download(url, fullname):
import requests
# using requests.get method
fname = os.path.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
        # To guard against interrupted downloads, write to tmp_fullname
        # first and move it to fullname only after the download finishes
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _download(url, path):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.split(url)[-1]
fullname = os.path.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
DOWNLOAD_RETRY_LIMIT = 3
while not os.path.exists(fullname):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _get_download(url, fullname):
time.sleep(1)
continue
return fullname
def _uncompress_file_zip(filepath):
with zipfile.ZipFile(filepath, 'r') as files:
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
# `strip(os.sep)` to remove `os.sep` in the tail of path
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
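    # Illustrative sketch (assumed inputs, not executed): how archive name
    # lists are classified by the two helpers above, e.g. on POSIX:
    #   _is_a_single_file(['model.bin'])           -> True  (one entry, no separator)
    #   _is_a_single_dir(['m/a.bin', 'm/b.json'])  -> True  (shared top-level dir 'm')
    #   _is_a_single_dir(['a.bin', 'm/b.json'])    -> False (mixed top levels)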
def _uncompress_file_tar(filepath, mode="r:*"):
with tarfile.open(filepath, mode) as files:
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
        # dispatch on the archive type; extraction happens next to the
        # archive (see _uncompress_file_zip / _uncompress_file_tar for
        # how the output directory is chosen)
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
            raise TypeError("Unsupported compressed file type {}".format(fname))
return uncompressed_path
    assert is_url(url), "downloading from {}, which is not a URL".format(url)
fullpath = _map_path(url, root_dir)
if os.path.exists(fullpath) and check_exist:
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir)
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
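# Hedged usage sketch for get_path_from_url: fetch one resource file into a
# local model directory (URL taken from MODEL_MAP below). vocab.txt is not an
# archive, so it is returned as-is under uie-base/.
# >>> vocab_path = get_path_from_url(
# ...     "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
# ...     root_dir="uie-base")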
MODEL_MAP = {
"uie-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
}
},
"uie-medium": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-mini": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-micro": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-nano": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-medical-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-tiny": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
}
}
}
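# Note: MODEL_MAP keys double as local directory names; check_model() below
# downloads every resource into "<model_name>/", saved under its URL basename:
# >>> MODEL_MAP['uie-base']['resource_file_urls']['vocab_file'].rsplit('/', 1)[-1]
# 'vocab.txt'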
def build_params_map(attention_num=12):
"""
build params map from paddle-paddle's ERNIE to transformer's BERT
:return:
"""
weight_map = collections.OrderedDict({
'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
        'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight",  # no 'bert.' prefix here; mapped directly onto the bert4torch structure
'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
})
# add attention layers
for i in range(attention_num):
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.bias'
weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.bias'
# add pooler
weight_map.update(
{
'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
'linear_start.weight': 'linear_start.weight',
'linear_start.bias': 'linear_start.bias',
'linear_end.weight': 'linear_end.weight',
'linear_end.bias': 'linear_end.bias',
}
)
return weight_map
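# Quick sanity sketch of the mapping above (illustrative, e.g. for layer 0):
# >>> wm = build_params_map(attention_num=12)
# >>> wm['encoder.encoder.layers.0.self_attn.q_proj.weight']
# 'bert.encoder.layer.0.attention.self.query.weight'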
def extract_and_convert(input_dir, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info('=' * 20 + 'save config file' + '=' * 20)
    with open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8') as f:
        config = json.load(f)
    config = config['init_args'][0]
config["architectures"] = ["UIE"]
config['layer_norm_eps'] = 1e-12
del config['init_class']
if 'sent_type_vocab_size' in config:
config['type_vocab_size'] = config['sent_type_vocab_size']
config['intermediate_size'] = 4 * config['hidden_size']
    with open(os.path.join(output_dir, 'config.json'), 'wt', encoding='utf-8') as f:
        json.dump(config, f, indent=4)
logger.info('=' * 20 + 'save vocab file' + '=' * 20)
with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
words = f.read().splitlines()
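    # Paddle vocab files may contain duplicated tokens; loading such a vocab
    # into a dict-based tokenizer would collapse entries and shift token ids.
    # Keep the last occurrence of each token and replace earlier duplicates
    # with unique placeholder characters (🚩, 🚪, ...) so ids stay one-to-one.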
words_set = set()
words_duplicate_indices = []
for i in range(len(words)-1, -1, -1):
word = words[i]
if word in words_set:
words_duplicate_indices.append(i)
words_set.add(word)
for i, idx in enumerate(words_duplicate_indices):
words[idx] = chr(0x1F6A9+i) # Change duplicated word to 🚩 LOL
with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
for word in words:
f.write(word+'\n')
special_tokens_map = {
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]"
}
    with open(os.path.join(output_dir, 'special_tokens_map.json'), 'wt', encoding='utf-8') as f:
        json.dump(special_tokens_map, f)
tokenizer_config = {
"do_lower_case": True,
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"tokenizer_class": "BertTokenizer"
}
    with open(os.path.join(output_dir, 'tokenizer_config.json'), 'wt', encoding='utf-8') as f:
        json.dump(tokenizer_config, f)
logger.info('=' * 20 + 'extract weights' + '=' * 20)
state_dict = collections.OrderedDict()
weight_map = build_params_map(attention_num=config['num_hidden_layers'])
    with open(os.path.join(input_dir, 'model_state.pdparams'), 'rb') as f:
        paddle_paddle_params = pickle.load(f)
    del paddle_paddle_params['StructuredToParameterName@@']
for weight_name, weight_value in paddle_paddle_params.items():
        if 'weight' in weight_name:
            if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
                # paddle nn.Linear stores weights as (in_features, out_features);
                # torch expects (out_features, in_features), so transpose
                weight_value = weight_value.transpose()
        # Fix: zero out row 0 (the [PAD] token) of the word embedding matrix
        if 'word_embeddings.weight' in weight_name:
            weight_value[0, :] = 0
if weight_name not in weight_map:
logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
continue
state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
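# Hedged sketch: load the converted checkpoint back with plain torch to check
# the remapped names (paths assume the default CLI arguments below).
# >>> sd = torch.load('uie_base_pytorch/pytorch_model.bin', map_location='cpu')
# >>> 'bert.embeddings.word_embeddings.weight' in sd
# True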
def check_model(input_model):
if not os.path.exists(input_model):
if input_model not in MODEL_MAP:
            raise ValueError(
                "input_model {} does not exist and is not a known model name in MODEL_MAP".format(input_model))
resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
logger.info("Downloading resource files...")
        for key, val in resource_file_urls.items():
            # downloads are saved under the URL basename, so check that name
            # (the map key is only a label and differs for the tokenizer files)
            file_path = os.path.join(input_model, os.path.split(val)[-1])
            if not os.path.exists(file_path):
                get_path_from_url(val, input_model)
def do_main():
check_model(args.input_model)
extract_and_convert(args.input_model, args.output_model)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_model", default="uie-base", type=str,
help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]")
parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
help="Directory of output pytorch model")
args = parser.parse_args()
do_main()
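# Example invocation (assumes network access to the bcebos URLs above):
#   python convert.py -i uie-base -o uie_base_pytorch
#   python convert.py -i uie-tiny -o uie_tiny_pytorch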