"""Score generated expressions against their reference expressions.

Reads a generation log in which lines starting with ``T`` carry reference
tokens and lines starting with ``H`` carry hypothesis tokens. Every aligned
reference/hypothesis pair is compared by:

* normalized token edit distance;
* exact, structural, and (optionally) sympy-simplified equivalence.

Per-sample results and a summary are written to a results file.

NOTE(review): the ``T-<id><TAB><tokens>`` / ``H-<id><TAB><field><TAB><tokens>``
layout looks like fairseq-generate output -- confirm.
"""
import argparse
import re
import sys

import networkx as nx  # NOTE: unused while the graph-edit-distance metric is disabled
import sympy as sp
from Levenshtein import distance
from networkx import graph_edit_distance  # NOTE: unused, see above
from tqdm import tqdm

from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint

# Token classes used by do_structure_match. ``re.match`` anchors only at the
# start of the token, so e.g. both "c12" and "c12x" count as constants.
_CONST_RE = re.compile(r"c[0-9]+")
_VAR_RE = re.compile(r"x[0-9]+")


def percent(a, n):
    """Format ``a / n`` as a percentage string with one decimal place.

    Returns ``"n/a"`` when ``n`` is 0 instead of raising ZeroDivisionError.
    """
    if n == 0:
        return "n/a"
    return f"{a / n * 100:0.1f}%"


def do_simplify_match(orig_expr, gen_expr):
    """Return True if both sympy expressions compare equal after ``sp.simplify``."""
    return bool(sp.simplify(orig_expr) == sp.simplify(gen_expr))


def _same_token_class(orig, gen):
    """Return True if two aligned tokens are interchangeable for structure matching."""
    return bool(
        (_CONST_RE.match(orig) and _CONST_RE.match(gen))
        or (_VAR_RE.match(orig) and _VAR_RE.match(gen))
        or (isint(orig) and isint(gen))
        or (orig.startswith("INT") and gen.startswith("INT"))
        or orig == gen
    )


def do_structure_match(orig_toks, gen_toks):
    """Return True if two token sequences have the same shape.

    The sequences must have equal length. Every aligned pair must either be
    identical, or belong to the same token class: both constants
    (``c<digits>``), both variables (``x<digits>``), both integers according
    to ``isint``, or both starting with ``"INT"``.
    """
    if len(orig_toks) != len(gen_toks):
        return False
    return all(_same_token_class(o, g) for o, g in zip(orig_toks, gen_toks))


def _field_tokens(comps, idx):
    """Return the space-separated tokens of tab-separated field ``idx``.

    An empty last field is removed by ``str.strip`` together with its tab, so
    a missing field is treated as empty (``[""]``) instead of raising
    IndexError.
    """
    field = comps[idx] if idx < len(comps) else ""
    return field.split(" ")


def _read_generate_output(path):
    """Read ``(sample_id, tokens)`` pairs for references and hypotheses.

    ``T`` lines supply the reference tokens in field 1. ``H`` lines supply the
    hypothesis tokens in field 2. The sample id is parsed from field 0 after
    skipping its first two characters (e.g. ``"T-"``).
    """
    orig_list = []
    gen_list = []
    with open(path, "r") as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            if line.startswith("T"):
                orig_list.append((int(comps[0][2:]), _field_tokens(comps, 1)))
            elif line.startswith("H"):
                gen_list.append((int(comps[0][2:]), _field_tokens(comps, 2)))
    return orig_list, gen_list


def _parse(tokens, postfix):
    """Parse a token sequence with the postfix or the prefix sympy parser."""
    if postfix:
        return parse_postfix_to_sympy(tokens)
    return parse_prefix_to_sympy(tokens)


def _evaluate(orig_list, gen_list, postfix=False, simplify=False):
    """Compare aligned reference/hypothesis pairs.

    Args:
        orig_list: ``(sample_id, tokens)`` reference pairs.
        gen_list: ``(sample_id, tokens)`` hypothesis pairs, in the same order.
        postfix: parse tokens as postfix instead of prefix notation.
        simplify: additionally test equivalence after ``sympy.simplify``.

    Returns:
        A pair ``(results, stats)``.

        * ``results`` holds one dict per compared pair, with keys ``id``,
          ``aed``, ``parsed`` and ``matched``.
        * ``stats`` maps ``gen_errors``, ``parsed``, ``exact_match``,
          ``structure_match`` and ``simplify_match`` to lists of sample ids,
          and ``all_aed`` to the normalized edit distances.

    Raises:
        ValueError: if a reference and a hypothesis with different ids are
            paired (the two lists are misaligned).
        Any exception raised while parsing a *reference* propagates.
    """
    stats = {
        key: []
        for key in (
            "gen_errors",
            "parsed",
            "exact_match",
            "structure_match",
            "simplify_match",
            "all_aed",
        )
    }
    results = []
    pairs = zip(orig_list, gen_list)
    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(
        pairs, desc="Parsing expressions", total=len(orig_list)
    ):
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if orig_num != gen_num:
            raise ValueError(
                f"Misaligned samples: reference {orig_num} paired with hypothesis {gen_num}"
            )
        # Token edit distance normalized by the combined length (0 == identical).
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        stats["all_aed"].append(aed)
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}
        # Every pair yields exactly one result row; later steps update ``res`` in place.
        results.append(res)

        if aed == 0:
            # Identical token sequences are counted as parsed, exact and structure
            # matches without invoking the parser.
            for key in ("parsed", "exact_match", "structure_match"):
                stats[key].append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            continue

        if "<>" in orig_toks:
            # References containing "<>" are reported as unparsed and unmatched.
            # This runs before the structure check, so such samples are not added
            # to structure_match and the summary stays consistent with the
            # per-sample rows.
            # NOTE(review): where "<>" comes from is unclear (possibly a mangled
            # "<unk>") -- confirm.
            continue

        if do_structure_match(orig_toks, gen_toks):
            stats["structure_match"].append(orig_num)
            res["matched"] = "Structure"

        orig_expr = _parse(orig_toks, postfix)
        try:
            gen_expr = _parse(gen_toks, postfix)
        except Exception:
            # Best effort: any failure to parse the hypothesis counts as a parse
            # error. KeyboardInterrupt and SystemExit are not swallowed.
            stats["gen_errors"].append(gen_num)
            continue
        res["parsed"] = True
        stats["parsed"].append(gen_num)

        if orig_expr == gen_expr:
            stats["exact_match"].append(gen_num)
            res["matched"] = "Exact"
        elif simplify and do_simplify_match(orig_expr, gen_expr):
            stats["simplify_match"].append(gen_num)
            res["matched"] = "Simplify"
    return results, stats


def _write_report(path, results, stats, n_total, simplify=False):
    """Write per-sample result lines, a blank line, then summary lines to ``path``."""
    with open(path, "w") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", n_total, file=resf)
        for label, key in (
            ("Parse error", "gen_errors"),
            ("Exact match", "exact_match"),
            ("Structure match", "structure_match"),
        ):
            count = len(stats[key])
            print(label, count, percent(count, n_total), file=resf)
        if simplify:
            count = len(stats["simplify_match"])
            print("Simplify match", count, percent(count, n_total), file=resf)
        all_aed = stats["all_aed"]
        # On empty input, skip the average instead of dividing by zero.
        if all_aed:
            print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)


def main(argv=None):
    """Command-line entry point; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args(argv)

    orig_list, gen_list = _read_generate_output(args.g)
    if len(orig_list) != len(gen_list):
        # zip() in _evaluate silently drops unpaired samples; make that visible.
        print(
            f"Warning: {len(orig_list)} reference lines but {len(gen_list)} "
            f"hypothesis lines; only the first "
            f"{min(len(orig_list), len(gen_list))} pairs are compared.",
            file=sys.stderr,
        )
    results, stats = _evaluate(
        orig_list, gen_list, postfix=args.postfix, simplify=args.simplify
    )
    # Percentages are relative to the number of reference lines, so unpaired
    # references count as misses.
    _write_report(args.r, results, stats, len(orig_list), simplify=args.simplify)


if __name__ == "__main__":
    main()