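# Decompile the x64 assembly functions in ../slade/slade_cps_dataset.json with Nova
# (lt-asset/nova-6.7b-bcr) and write the sampled source-code predictions, plus per-function
# timings, to nova_predictions.txt.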
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from modeling_nova import NovaTokenizer, NovaForCausalLM
import time
tokenizer = AutoTokenizer.from_pretrained('lt-asset/nova-6.7b-bcr', trust_remote_code=True)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print('Vocabulary:', len(tokenizer.get_vocab()))  # 32280
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
nova_tokenizer = NovaTokenizer(tokenizer)
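# NovaTokenizer wraps the base HF tokenizer and, given the char_types string built below,
# also emits the nova_attention_mask / no_mask_idx tensors that model.generate expects.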
model = NovaForCausalLM.from_pretrained('lt-asset/nova-6.7b-bcr').eval()
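# Note (not in the original script): with no torch_dtype argument the 6.7B weights load in
# fp32 (~27 GB) on CPU; passing torch_dtype=torch.bfloat16 to from_pretrained and moving the
# model to a GPU would likely be needed for reasonable generation speed.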
with open("../slade/slade_cps_dataset.json", "r") as f:
dataset = json.load(f)
# GCC/Clang emit .cfi_* unwind directives around each function body; they are metadata,
# not instructions, so they are stripped before the assembly is shown to the model.
CFI_START = ".cfi_startproc"
CFI_END = ".cfi_endproc"

def process_asm(asm):
    """Build the Nova prompt and the per-character type string for one assembly function."""
    if CFI_START in asm:
        s = asm.index(CFI_START) + len(CFI_START)
    else:
        s = 0
    if CFI_END in asm:
        e = asm.index(CFI_END)
    else:
        e = len(asm)  # fall back to the whole string (e = 0 would slice away the function)
    prompt_before = '# This is the assembly code:\n<func0>:\n'
    prompt_after = '\nWhat is the source code?\n'
    # Drop remaining .cfi_* lines and blanks, then append Nova's <label-N> marker to each instruction.
    lines = filter(lambda ln: (".cfi_" not in ln) and (ln.strip() != ""), asm[s:e].split("\n\t"))
    asm = "\n".join(f"{l}\t<label-{i+1}>" for i, l in enumerate(lines))
    # char_types marks assembly characters with '1' and natural-language prompt characters with '0'.
    char_types = ("0" * len(prompt_before)) + ("1" * len(asm)) + ("0" * len(prompt_after))
    return prompt_before + asm + prompt_after, char_types
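# Sketch of the transformation on a toy (hypothetical) input:
#   process_asm(".cfi_startproc\n\tendbr64\n\tretq\n\t.cfi_endproc")
# produces a prompt whose assembly body is
#   "endbr64\t<label-1>\nretq\t<label-2>"
# sandwiched between prompt_before and prompt_after, with a char_types string of the same
# length carrying '1' over exactly those assembly characters and '0' elsewhere.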
# Flatten the dataset into (program, function, assembly) triples, keeping only programs
# whose name marks them as x64 builds.
asms = [(prog, func, asm) for prog in dataset for func, asm in dataset[prog].items() if "_x64_" in prog]
with open("nova_predictions.txt", "w") as predf:
for prog, func, asm in tqdm(asms):
start = time.time()
inputs, char_types = process_asm(asm)
toks = nova_tokenizer.encode(inputs, "", char_types)
input_ids = torch.LongTensor(toks['input_ids'].tolist()).unsqueeze(0)
nova_attention_mask = torch.LongTensor(toks['nova_attention_mask']).unsqueeze(0)
no_mask_id = torch.LongTensor([toks['no_mask_idx']])
outputs = model.generate(
inputs=input_ids, max_new_tokens=512, temperature=0.2, top_p=0.95, num_return_sequences=3,
do_sample=True, nova_attention_mask=nova_attention_mask, no_mask_idx=no_mask_id,
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id
)
end = time.time()
predf.write(f"{prog} {func} time= {end-start}\n")
for output in outputs:
outc = tokenizer.decode(output[input_ids.size(1):], skip_special_tokens=True, clean_up_tokenization_spaces=True)
predf.write(f"\t{outc}\n")