# REMEND: remend/bpe.py (commit 7145fd6, "Add REMEND python module")
import os

from tqdm import tqdm

from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
class ImmPreTokenizer:
    """Custom pre-tokenizer that splits numeric immediates into single characters."""

    def pre_tokenize(self, pretok):
        pretok.split(self.hex_imm_split)

    def hex_imm_split(self, i, norm_str):
        tok = str(norm_str)
        if tok[:2] == "0x" or tok.isdigit():
            # Hex (0x-prefixed) or decimal immediates are split character by
            # character so BPE works on digits instead of whole constants.
            return [norm_str[j:j + 1] for j in range(len(tok))]
        else:
            return [norm_str]
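# Hedged example (illustrative only, not part of the original module): assuming
# the HuggingFace `tokenizers` custom pre-tokenizer API, the combined
# pre-tokenizer first splits on whitespace and then breaks an immediate such as
# "0x3f" into single characters:
#
#     pt = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
#     pt.pre_tokenize_str("mov eax , 0x3f")
#     # -> [("mov", ...), ("eax", ...), (",", ...), ("0", ...), ("x", ...), ("3", ...), ("f", ...)]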
def get_asm_tok(files, save):
    """Train a BPE tokenizer on the given asm files and save it to `save`."""
    asm_tok = Tokenizer(BPE(unk_token="@@UNK@@"))
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence(
        [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())]
    )
    asm_train = BpeTrainer(special_tokens=["@@UNK@@"])
    asm_tok.train(files, asm_train)
    # Custom pre-tokenizers cannot be serialized, so temporarily swap in plain
    # Whitespace before saving, then restore ImmPreTokenizer afterwards.
    asm_tok.pre_tokenizer = Whitespace()
    asm_tok.save(save)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence(
        [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())]
    )
    return asm_tok
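# Hedged usage sketch (the file names below are placeholders, not from the source):
#
#     tok = get_asm_tok(["train.asm", "valid.asm"], "asm_tokens.json")
#     tok.encode("mov eax , 0x3f").tokens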
def load_asm_tok(load):
    """Load a saved tokenizer and re-attach the custom ImmPreTokenizer."""
    asm_tok = Tokenizer.from_file(load)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence(
        [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())]
    )
    return asm_tok
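# Hedged usage sketch: reload the saved vocabulary and tokenize one line,
# assuming "asm_tokens.json" was produced by get_asm_tok above:
#
#     tok = load_asm_tok("asm_tokens.json")
#     print(tok.encode("mov eax , 0x3f").tokens)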
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the tokenizer and tokenize the asm")
    parser.add_argument("-i", "--indir", required=True, help="input directory")
    parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    injoin = lambda p: os.path.join(args.indir, p)
    pjoin = lambda p: os.path.join(args.outdir, p)

    max_asm_toks = 0
    # Train on the train and valid splits, saving the vocabulary to the output directory.
    asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
    for split in ["train", "valid", "test"]:
        asmfile = split + ".asm"
        with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
            for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
                asm = asm.strip()
                asm_enc = asm_tok.encode(asm)
                max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
                asm_seq = " ".join(asm_enc.tokens)
                asmtokf.write(asm_seq + "\n")
    print("Maximum tokens:", max_asm_toks)

    # After this, run command:
    # fairseq-preprocess -s asm -t eqn --trainpref {OUTDIR}/train --validpref {OUTDIR}/valid --testpref {OUTDIR}/test --destdir {OUTDIR}