|
import sympy as sp |
|
import sys |
|
import re |
|
from tqdm import tqdm |
|
from Levenshtein import distance |
|
import networkx as nx |
|
from networkx import graph_edit_distance |
|
|
|
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint |
|
|
|
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Args:
        a: Numerator (e.g. a count of matching items).
        n: Denominator (e.g. the total number of items).

    Returns:
        A string such as ``"25.0%"``. When ``n`` is zero the ratio is
        undefined; ``"0.0%"`` is returned instead of raising
        ``ZeroDivisionError`` so that reports over empty inputs can still
        be written.
    """
    if n == 0:
        return "0.0%"
    return f"{a/n*100:0.1f}%"
|
|
|
def do_simplify_match(orig_expr, gen_expr):
    """Report whether two expressions are equal after ``sp.simplify``.

    Both expressions are passed through ``sp.simplify`` and the simplified
    results are compared with ``==``.

    Args:
        orig_expr: The reference expression.
        gen_expr: The generated expression.

    Returns:
        ``True`` if the simplified forms compare equal, ``False`` otherwise.
    """
    simplified_reference = sp.simplify(orig_expr)
    simplified_candidate = sp.simplify(gen_expr)
    return bool(simplified_reference == simplified_candidate)
|
|
|
def do_structure_match(orig_toks, gen_toks):
    """Check whether two token sequences agree position by position.

    Two tokens at the same position are considered compatible when any of
    the following holds (checked in this order):

    * both start with ``c`` followed by at least one digit,
    * both start with ``x`` followed by at least one digit,
    * ``isint`` is truthy for both,
    * both start with the literal prefix ``"INT"``,
    * they are equal.

    Args:
        orig_toks: Sequence of reference token strings.
        gen_toks: Sequence of generated token strings.

    Returns:
        ``False`` if the sequences differ in length or any position holds an
        incompatible pair; ``True`` otherwise (including two empty sequences).
    """

    def _both_match(pattern, left, right):
        # re.match anchors at the start of the token only.
        return re.match(pattern, left) and re.match(pattern, right)

    def _compatible(left, right):
        return bool(
            _both_match(r"c[0-9]+", left, right)
            or _both_match(r"x[0-9]+", left, right)
            or (isint(left) and isint(right))
            or (left.startswith("INT") and right.startswith("INT"))
            or left == right
        )

    if len(orig_toks) != len(gen_toks):
        return False

    # all() stops at the first incompatible pair.
    return all(
        _compatible(reference, candidate)
        for reference, candidate in zip(orig_toks, gen_toks)
    )
|
|
|
if __name__ == "__main__":
    # Script entry point: read reference/generated token sequences from the
    # file given by -g, compare each pair, and write per-item results plus a
    # summary to the file given by -r.
    import argparse

    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    # Lines starting with 'T' carry the reference tokens in the second
    # tab-separated field; lines starting with 'H' carry the generated tokens
    # in the third field. In both cases the numeric id is the first field with
    # its first two characters dropped. All other lines are ignored.
    orig_list = []
    gen_list = []
    with open(args.g, 'r') as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            if line[0] == 'T':
                num = int(comps[0][2:])
                tokens = comps[1].split(" ")
                orig_list.append((num, tokens))
            elif line[0] == 'H':
                num = int(comps[0][2:])
                tokens = comps[2].split(" ")
                gen_list.append((num, tokens))

    N = len(orig_list)
    gen_errors = []
    parsed = []
    exact_match = []
    structure_match = []
    simplify_match = []

    orig_exprs = {}
    gen_exprs = {}

    all_aed = []

    results = []

    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(zip(orig_list, gen_list), desc="Parsing expressions", total=N):
        # Explicit check instead of `assert`, which is stripped under -O.
        if orig_num != gen_num:
            raise ValueError(
                f"Reference id {orig_num} does not line up with generated id {gen_num}"
            )

        # Edit distance over tokens, normalised by the combined length.
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        all_aed.append(aed)
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}

        # Identical token sequences: no need to parse anything.
        if aed == 0:
            parsed.append(orig_num)
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            results.append(res)
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"

        # References containing the unknown-token marker are not parsed; the
        # item is recorded as unparsed and unmatched.
        if "<<unk>>" in orig_toks:
            res["parsed"] = False
            res["matched"] = False
            results.append(res)
            continue

        if args.postfix:
            orig_expr = parse_postfix_to_sympy(orig_toks)
        else:
            orig_expr = parse_prefix_to_sympy(orig_toks)

        try:
            if args.postfix:
                gen_expr = parse_postfix_to_sympy(gen_toks)
            else:
                gen_expr = parse_prefix_to_sympy(gen_toks)
            res["parsed"] = True
        except Exception:
            # Any failure while parsing the generated tokens is counted as a
            # parse error. Catching Exception (not a bare except) lets
            # KeyboardInterrupt and SystemExit propagate.
            gen_errors.append(gen_num)
            results.append(res)
            continue

        parsed.append(gen_num)
        orig_exprs[gen_num] = orig_expr
        gen_exprs[gen_num] = gen_expr

        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"
        results.append(res)

    with open(args.r, "w") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        # all_aed is empty when no reference/generated pair was processed
        # (e.g. the input has no 'H' lines); avoid dividing by zero and
        # calling max() on an empty list in that case.
        mean_aed = sum(all_aed) / len(all_aed) if all_aed else 0.0
        print("Avg SED", mean_aed, max(all_aed, default=0.0), file=resf)
|
|
|
|
|
|