from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
import random
import os
from tqdm import tqdm
import itertools as it


class ImmPreTokenizer:
    """Custom pre-tokenizer step that splits numeric immediates into single
    characters.

    Intended to be wrapped with ``PreTokenizer.custom(...)`` and composed
    after ``Whitespace()`` so that whole-number tokens (hex like ``0x1f`` or
    plain decimals) are broken down character by character before BPE.
    """

    def pre_tokenize(self, pretok):
        # Entry point invoked by the tokenizers library on a
        # PreTokenizedString; delegate the per-piece splitting logic.
        pretok.split(self.hex_imm_split)

    def hex_imm_split(self, i, norm_str):
        # i: index of the piece within the pre-tokenized sequence (unused —
        # and shadowed by the comprehension variable below).
        # norm_str: a NormalizedString piece produced by the previous
        # pre-tokenizer stage (presumably a whitespace-delimited token —
        # TODO confirm against the Whitespace() stage).
        tok = str(norm_str)
        if tok[:2] == "0x" or tok.isdigit():
            # Immediate value: return one single-character slice per
            # character so BPE sees digits individually.
            return [norm_str[i:i+1] for i in range(len(tok))]
        else:
            # Not an immediate: keep the piece intact.
            return [norm_str]


def get_asm_tok(files, save):
    """Train a BPE tokenizer for assembly text and save it to *save*.

    files: list of paths to plain-text assembly files used for training.
    save:  path the trained tokenizer JSON is written to.
    Returns the trained Tokenizer with the custom immediate-splitting
    pre-tokenizer installed.
    """
    asm_tok = Tokenizer(BPE(unk_token="@@UNK@@"))
    # Whitespace first, then split numeric immediates into characters.
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    asm_train = BpeTrainer(special_tokens=["@@UNK@@"])
    asm_tok.train(files, asm_train)
    # Custom Python pre-tokenizers cannot be serialized, so temporarily swap
    # in a plain Whitespace pre-tokenizer for save(), then restore.
    asm_tok.pre_tokenizer = Whitespace()  # Hack to save, careful to restore ImmPreTokenizer
    asm_tok.save(save)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    return asm_tok


def load_asm_tok(load):
    """Load a tokenizer saved by get_asm_tok from *load* and re-attach the
    custom immediate-splitting pre-tokenizer (which save() could not
    serialize)."""
    asm_tok = Tokenizer.from_file(load)
    asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    return asm_tok


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Train the tokenizer and tokenize the asm")
    parser.add_argument("-i", "--indir", required=True, help="output directory")
    parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    # Path helpers for the input and output directories.
    injoin = lambda p: os.path.join(args.indir, p)
    pjoin = lambda p: os.path.join(args.outdir, p)

    # Track the longest token sequence seen across all splits.
    max_asm_toks = 0
    # Train only on train+valid; test is held out from tokenizer training.
    asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
    for split in ["train", "valid", "test"]:
        asmfile = split + ".asm"
        with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
            for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
                asm = asm.strip()
                asm_enc = asm_tok.encode(asm)
                max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
                # Write the tokenized line as space-separated subword tokens.
                asm_seq = " ".join(asm_enc.tokens)
                asmtokf.write(asm_seq + "\n")
    print("Maximum tokens:", max_asm_toks)

# After this, run command:
# fairseq-preprocess -s asm -t eqn --trainpref {OUTDIR}/train --validpref {OUTDIR}/valid --testpref {OUTDIR}/test --destdir {OUTDIR}