from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
import os
from tqdm import tqdm


class ImmPreTokenizer:
    """Custom pre-tokenizer that splits hex/decimal immediates into single characters."""

    def pre_tokenize(self, pretok):
        pretok.split(self.hex_imm_split)

    def hex_imm_split(self, i, norm_str):
        tok = str(norm_str)
        if tok[:2] == "0x" or tok.isdigit():
            # Immediate operand: emit one single-character NormalizedString per character.
            return [norm_str[j:j + 1] for j in range(len(tok))]
        else:
            # Any other token (mnemonic, register, label, ...) is left intact.
            return [norm_str]
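# A minimal sketch of what the custom splitter does, assuming the standard
# `pre_tokenize_str` API on a composed pre-tokenizer; the instruction string
# below is an invented example, not taken from the training data:
#
#   pt = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
#   pt.pre_tokenize_str("mov eax , 0x1f")
#   # -> whitespace splitting first, then the hex immediate "0x1f" is broken
#   #    into the single characters "0", "x", "1", "f"; "mov"/"eax" stay whole.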


def get_asm_tok(files, save):
    """Train a BPE tokenizer on `files` and save it to `save`."""
    asm_tok = Tokenizer(BPE(unk_token="@@UNK@@"))
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    asm_train = BpeTrainer(special_tokens=["@@UNK@@"])

    asm_tok.train(files, asm_train)
    # A tokenizer carrying a custom Python pre-tokenizer cannot be serialized,
    # so swap in plain Whitespace() for the save, then restore the custom
    # pre-tokenizer afterwards.
    asm_tok.pre_tokenizer = Whitespace()
    asm_tok.save(save)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])

    return asm_tok


def load_asm_tok(load):
    """Load a saved tokenizer and re-attach the custom immediate pre-tokenizer."""
    asm_tok = Tokenizer.from_file(load)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    return asm_tok
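# Usage sketch (the path and input line are assumptions for illustration, not
# part of the original pipeline): load a saved tokenizer and encode one line
# of assembly, with the custom immediate splitting applied.
#
#   asm_tok = load_asm_tok("tokenized/asm_tokens.json")
#   enc = asm_tok.encode("mov eax , 0x1f")
#   print(enc.tokens)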


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the tokenizer and tokenize the asm")
    parser.add_argument("-i", "--indir", required=True, help="input directory")
    parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    injoin = lambda p: os.path.join(args.indir, p)
    pjoin = lambda p: os.path.join(args.outdir, p)
    max_asm_toks = 0

    # Train on the train and valid splits, saving the tokenizer into the output directory.
    asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
    for split in ["train", "valid", "test"]:
        asmfile = split + ".asm"
        with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
            for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
                asm = asm.strip()
                asm_enc = asm_tok.encode(asm)
                max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
                asm_seq = " ".join(asm_enc.tokens)
                asmtokf.write(asm_seq + "\n")

    print("Maximum tokens:", max_asm_toks)