REMEND / remend /check_generated.py
udiboy1209's picture
Add REMEND python module
7145fd6
raw
history blame
4.84 kB
import sympy as sp
import sys
import re
from tqdm import tqdm
from Levenshtein import distance
import networkx as nx
from networkx import graph_edit_distance
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Returns ``"n/a"`` when ``n`` is zero (e.g. an empty input file) instead
    of raising ``ZeroDivisionError``.

    >>> percent(1, 4)
    '25.0%'
    """
    if n == 0:
        return "n/a"
    return f"{a / n * 100:0.1f}%"
def do_simplify_match(orig_expr, gen_expr):
    """Return True if both expressions are identical after ``sympy.simplify``.

    Each expression is simplified independently and the results are compared
    with ``==``, which for SymPy objects is structural (not mathematical)
    equality. ``simplify`` is heuristic and does not guarantee a canonical
    form, so a False result does not prove the expressions differ.
    """
    return bool(sp.simplify(orig_expr) == sp.simplify(gen_expr))
def do_structure_match(orig_toks, gen_toks):
    """Check whether two token sequences share the same structure.

    The sequences must have equal length, and each aligned token pair must
    either be equal or fall into the same placeholder class:

    * both constants  -- ``c[0-9]+`` matches at the start of the token,
    * both variables  -- ``x[0-9]+`` matches at the start of the token,
    * both integers   -- according to ``isint``,
    * both tokens start with ``INT``.

    ``re.match`` anchors only at the start of the token, so e.g. ``"c1abc"``
    is also classed as a constant.

    Returns:
        True if every aligned pair agrees, False otherwise.
    """
    if len(orig_toks) != len(gen_toks):
        return False

    const_pat = re.compile(r"c[0-9]+")
    var_pat = re.compile(r"x[0-9]+")

    def _same_kind(a, b):
        # Checks run in a fixed order; `isint` is consulted only when the
        # constant/variable patterns have not already matched the pair.
        if const_pat.match(a) and const_pat.match(b):
            return True
        if var_pat.match(a) and var_pat.match(b):
            return True
        if isint(a) and isint(b):
            return True
        if a.startswith("INT") and b.startswith("INT"):
            return True
        return a == b

    return all(_same_kind(a, b) for a, b in zip(orig_toks, gen_toks))
def _read_generated(path):
    """Collect reference and hypothesis token sequences from a generation log.

    Lines starting with ``T-`` are references laid out as
    ``T-<id><TAB><tokens>``; lines starting with ``H-`` are hypotheses laid
    out as ``H-<id><TAB><field><TAB><tokens>`` (the middle field is skipped --
    presumably a model score; confirm against the generator). All other lines
    are ignored. Tokens are whitespace-separated; an empty token field yields
    an empty list.

    Returns:
        ``(orig_list, gen_list)``: two lists of ``(id, tokens)`` tuples in
        file order.
    """
    orig_list = []
    gen_list = []
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading file"):
            # Strip only the line terminator: a full strip() would also drop
            # the trailing tab of an empty hypothesis, leaving no comps[2].
            comps = line.rstrip("\r\n").split("\t")
            if line.startswith("T-"):
                orig_list.append((int(comps[0][2:]), comps[1].split()))
            elif line.startswith("H-"):
                gen_list.append((int(comps[0][2:]), comps[2].split()))
    return orig_list, gen_list


def _normalized_edit_distance(orig_toks, gen_toks):
    """Token-level Levenshtein distance divided by the combined length.

    Returns 0.0 when both sequences are empty (they are identical) instead of
    dividing by zero.
    """
    total = len(orig_toks) + len(gen_toks)
    if total == 0:
        return 0.0
    return distance(orig_toks, gen_toks) / total


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    orig_list, gen_list = _read_generated(args.g)
    # Pairing is positional; zip() would silently drop unpaired entries and
    # skew every percentage, so refuse to continue on a count mismatch.
    if len(orig_list) != len(gen_list):
        sys.exit(
            f"{args.g}: {len(orig_list)} reference (T-) lines but "
            f"{len(gen_list)} hypothesis (H-) lines"
        )
    N = len(orig_list)
    if N == 0:
        sys.exit(f"{args.g}: no reference/hypothesis lines found")

    parse_expr = parse_postfix_to_sympy if args.postfix else parse_prefix_to_sympy

    gen_errors = []  # ids whose hypothesis failed to parse
    exact_match = []
    structure_match = []
    simplify_match = []
    all_aed = []
    results = []  # one dict per id, written to the results file
    pairs = zip(orig_list, gen_list)
    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(
        pairs, desc="Parsing expressions", total=N
    ):
        # Explicit check rather than `assert`, so it survives `python -O`.
        if orig_num != gen_num:
            sys.exit(f"Reference id {orig_num} paired with hypothesis id {gen_num}")

        aed = _normalized_edit_distance(orig_toks, gen_toks)
        all_aed.append(aed)
        # Appended now and mutated below; every early `continue` leaves it in
        # the results with whatever has been set so far.
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}
        results.append(res)

        if aed == 0:
            # Token-identical output: exact (and structural) match, no parsing.
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            continue

        if "<<unk>>" in orig_toks:
            # Reference contains an unknown token (origin unclear): report it
            # as unparsed/unmatched. Checked before the structure match so the
            # "Structure match" total agrees with the per-id results.
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"

        # Parsed outside the try: a reference that fails to parse aborts.
        orig_expr = parse_expr(orig_toks)
        try:
            gen_expr = parse_expr(gen_toks)
        except Exception:
            # Any parser failure on model output is a counted parse error;
            # not a bare `except`, so Ctrl-C still interrupts the run.
            gen_errors.append(gen_num)
            continue
        res["parsed"] = True

        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"

    with open(args.r, "w", encoding="utf-8") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        # N > 0 is guaranteed above, so all_aed is non-empty here.
        print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)