REMEND / remend /check_generated.py
udiboy1209's picture
Add REMEND python module
7145fd6
import sympy as sp
import sys
import re
from tqdm import tqdm
from Levenshtein import distance
import networkx as nx
from networkx import graph_edit_distance
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Args:
        a: Numerator, e.g. the number of items in some category.
        n: Denominator, e.g. the total number of items.

    Returns:
        A string such as ``"25.0%"``. When ``n`` is zero there is no total
        to take a share of, so ``"0.0%"`` is returned instead of raising
        ``ZeroDivisionError``.
    """
    if n == 0:
        return "0.0%"
    return f"{a/n*100:0.1f}%"
def do_simplify_match(orig_expr, gen_expr):
    """Report whether two expressions compare equal after simplification.

    Each argument is passed through ``sp.simplify`` and the two simplified
    results are compared with ``==``.

    Args:
        orig_expr: The reference expression.
        gen_expr: The generated expression.

    Returns:
        ``True`` when the simplified forms compare equal, ``False`` otherwise.
    """
    simplified_orig = sp.simplify(orig_expr)
    simplified_gen = sp.simplify(gen_expr)
    return True if simplified_orig == simplified_gen else False
def do_structure_match(orig_toks, gen_toks):
    """Check whether two token sequences have the same structure.

    The sequences must have the same length, and every aligned pair of
    tokens must fall into the same category. Two tokens are considered
    interchangeable when, checked in this order:

    * both start with ``c`` followed by at least one digit,
    * both start with ``x`` followed by at least one digit,
    * both are accepted by ``isint``,
    * both start with ``"INT"``, or
    * they are equal strings.

    Args:
        orig_toks: Token strings of the reference expression.
        gen_toks: Token strings of the generated expression.

    Returns:
        ``True`` if every aligned pair is interchangeable, ``False`` if the
        lengths differ or any pair is not.
    """
    # Patterns are applied with ``match``, i.e. anchored at the start only.
    const_pattern = re.compile(r"c[0-9]+")
    var_pattern = re.compile(r"x[0-9]+")

    def _interchangeable(lhs, rhs):
        # The order of these checks mirrors the category list above; the
        # first category that both tokens share settles the question.
        if const_pattern.match(lhs) and const_pattern.match(rhs):
            return True
        if var_pattern.match(lhs) and var_pattern.match(rhs):
            return True
        if isint(lhs) and isint(rhs):
            return True
        if lhs.startswith("INT") and rhs.startswith("INT"):
            return True
        return lhs == rhs

    if len(orig_toks) != len(gen_toks):
        return False
    # ``all`` stops at the first pair that does not line up.
    return all(
        _interchangeable(expected, produced)
        for expected, produced in zip(orig_toks, gen_toks)
    )
if __name__ == "__main__":
    # Command-line entry point.
    #
    # Reads a tab-separated generation log (-g), pairs every reference line
    # (starting with 'T') with its hypothesis line (starting with 'H'),
    # compares the two token sequences, and writes one result row per pair
    # followed by summary statistics to the results file (-r).
    import argparse

    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    orig_list = []
    gen_list = []
    with open(args.g, 'r') as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            # The first field is a two-character tag followed by the sample
            # number (presumably fairseq-generate style "T-<n>" / "H-<n>" --
            # TODO confirm). Reference tokens are in the second field,
            # hypothesis tokens in the third.
            if line[0] == 'T':
                num = int(comps[0][2:])
                tokens = comps[1].split(" ")
                orig_list.append((num, tokens))
            elif line[0] == 'H':
                num = int(comps[0][2:])
                tokens = comps[2].split(" ")
                gen_list.append((num, tokens))

    N = len(orig_list)
    # Fail early with a clear message instead of a ZeroDivisionError /
    # "max() of empty sequence" while the summary is being written.
    if N == 0:
        sys.exit(f"No reference ('T') lines found in {args.g}; nothing to check")
    # zip() below would silently drop unpaired entries, which would skew
    # every percentage (they are all computed relative to N).
    if len(gen_list) != N:
        sys.exit(
            f"{args.g} has {N} reference ('T') lines but "
            f"{len(gen_list)} hypothesis ('H') lines"
        )

    # The token order is the same for every sample, so pick the parser once.
    parse_tokens = parse_postfix_to_sympy if args.postfix else parse_prefix_to_sympy

    gen_errors = []       # ids whose hypothesis failed to parse
    exact_match = []      # ids with identical tokens or equal parsed expressions
    structure_match = []  # ids whose tokens line up category by category
    simplify_match = []   # ids equal only after simplification (--simplify)
    all_aed = []          # per-sample normalised token edit distance
    # NOTE: graph edit distance (GED) reporting is currently disabled.
    results = []

    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(zip(orig_list, gen_list), desc="Parsing expressions", total=N):
        # Explicit check rather than `assert`, which is stripped under -O.
        if orig_num != gen_num:
            sys.exit(
                f"Sample id mismatch in {args.g}: "
                f"reference {orig_num} is paired with hypothesis {gen_num}"
            )

        # Edit distance over tokens, normalised by the combined length.
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        all_aed.append(aed)

        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}

        if aed == 0:
            # Identical token sequences: no need to parse anything.
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            results.append(res)
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"

        if "<<unk>>" in orig_toks:
            # The reference itself contains an unknown token, so the sample
            # is reported as neither parsed nor matched.
            # NOTE(review): a structure match recorded just above still
            # counts towards the "Structure match" total -- confirm whether
            # that is intended.
            res["parsed"] = False
            res["matched"] = False
            results.append(res)
            continue

        # The reference is parsed outside the try block on purpose: a
        # failure here is not a generation error and should surface.
        orig_expr = parse_tokens(orig_toks)

        try:
            gen_expr = parse_tokens(gen_toks)
        except Exception:
            # Any parser failure means the hypothesis is malformed.
            # `Exception` (not a bare `except`) keeps Ctrl-C and SystemExit
            # working during long runs.
            gen_errors.append(gen_num)
            results.append(res)
            continue
        res["parsed"] = True

        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"
        results.append(res)

    with open(args.r, "w") as resf:
        # One row per sample, then a blank line, then the summary.
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        # Mean and maximum of the normalised edit distances.
        print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)