File size: 4,840 Bytes
			
			7145fd6  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144  | 
								import sympy as sp
import sys
import re
from tqdm import tqdm
from Levenshtein import distance
import networkx as nx
from networkx import graph_edit_distance
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Args:
        a: Numerator (e.g. number of matching examples).
        n: Denominator (e.g. total number of examples).

    Returns:
        A string such as ``"25.0%"``, or ``"n/a"`` when ``n`` is zero and the
        ratio is undefined (instead of raising ``ZeroDivisionError``).
    """
    if n == 0:
        return "n/a"
    return f"{a/n*100:0.1f}%"
def do_simplify_match(orig_expr, gen_expr):
    """Report whether two SymPy expressions are equal after simplification.

    The ``sp.simplify`` results of both expressions are first compared
    directly. ``simplify`` does not produce a canonical form, so mathematically
    equivalent inputs can still simplify to different trees. As a fallback, the
    simplified difference is therefore tested against zero, which is the
    equality test SymPy's documentation recommends.

    Args:
        orig_expr: Reference expression.
        gen_expr: Generated expression.

    Returns:
        True if the expressions are judged equal, False otherwise.
    """
    orig_simp = sp.simplify(orig_expr)
    gen_simp = sp.simplify(gen_expr)
    if orig_simp == gen_simp:
        return True
    try:
        difference = orig_simp - gen_simp
    except TypeError:
        # Presumably non-arithmetic objects (e.g. relationals) that cannot be
        # subtracted; only the structural comparison above applies to them.
        return False
    return bool(sp.simplify(difference) == 0)
def do_structure_match(orig_toks, gen_toks):
    """Check whether two token sequences have the same structure.

    The sequences agree when they have equal length and every aligned pair of
    tokens either belongs to the same placeholder class or is identical. The
    placeholder classes are:

    - both constants (token starts with ``c`` followed by digits);
    - both variables (token starts with ``x`` followed by digits);
    - both integers according to ``isint``;
    - both starting with ``"INT"``.

    NOTE(review): ``re.match`` anchors only at the start of the token, so a
    token such as ``"c1foo"`` is also classified as a constant.

    Returns:
        True if every position agrees in this sense, False otherwise.
    """
    if len(orig_toks) != len(gen_toks):
        return False

    def same_kind(left, right):
        # The checks run in a fixed order, and isint is only consulted when
        # neither the constant nor the variable test succeeded.
        if re.match(r"c[0-9]+", left) and re.match(r"c[0-9]+", right):
            return True
        if re.match(r"x[0-9]+", left) and re.match(r"x[0-9]+", right):
            return True
        if isint(left) and isint(right):
            return True
        if left.startswith("INT") and right.startswith("INT"):
            return True
        return left == right

    # all() stops at the first mismatching position.
    return all(same_kind(left, right) for left, right in zip(orig_toks, gen_toks))
if __name__ == "__main__":
    # Command-line entry point. It pairs reference ("T") and generated ("H")
    # token sequences from a tab-separated generation log, scores each pair,
    # and writes per-example lines plus summary statistics to the results file.
    import argparse

    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    # The notation is fixed for the whole run, so pick the parser once.
    parse_to_sympy = parse_postfix_to_sympy if args.postfix else parse_prefix_to_sympy

    # Reference lines are read as "T<?><id>\t<tokens>". Generated lines are
    # read as "H<?><id>\t<ignored>\t<tokens>". The id starts at character 2.
    orig_list = []
    gen_list = []
    with open(args.g, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            if line.startswith("T"):
                orig_list.append((int(comps[0][2:]), comps[1].split(" ")))
            elif line.startswith("H"):
                gen_list.append((int(comps[0][2:]), comps[2].split(" ")))

    # Without at least one pair, every percentage and the AED average below
    # would divide by zero. Fail with a readable message instead.
    if not orig_list or not gen_list:
        sys.exit(f"No T/H expression pairs found in {args.g}")
    if len(orig_list) != len(gen_list):
        # zip() below silently stops at the shorter list.
        print(f"Warning: {len(orig_list)} reference vs {len(gen_list)} generated "
              "expressions; unpaired entries are not evaluated.", file=sys.stderr)

    N = len(orig_list)
    gen_errors = []       # ids whose generated tokens failed to parse
    parsed = []           # ids that parsed (or matched token-for-token)
    exact_match = []
    structure_match = []
    simplify_match = []
    orig_exprs = {}
    gen_exprs = {}
    all_aed = []          # length-normalized token edit distance per pair
    results = []
    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(
            zip(orig_list, gen_list), desc="Parsing expressions", total=N):
        # Explicit check rather than assert, so it still runs under "python -O".
        if orig_num != gen_num:
            raise ValueError(
                f"Misaligned pair: reference id {orig_num}, generated id {gen_num}")
        # Edit distance between the token lists, divided by their combined length.
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        all_aed.append(aed)
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}
        if aed == 0:
            # Identical token sequences count as parsed, exact and structural.
            parsed.append(orig_num)
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            results.append(res)
            continue
        if "<<unk>>" in orig_toks:
            # The reference contains an unknown token, so the pair is reported
            # as neither parsed nor matched. This check runs before the
            # structure test so the summary count agrees with the per-line result.
            results.append(res)
            continue
        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"
        orig_expr = parse_to_sympy(orig_toks)
        try:
            gen_expr = parse_to_sympy(gen_toks)
        except Exception:
            # Record a malformed generated sequence and move on. A bare
            # except would also swallow KeyboardInterrupt/SystemExit.
            gen_errors.append(gen_num)
            results.append(res)
            continue
        res["parsed"] = True
        parsed.append(gen_num)
        orig_exprs[gen_num] = orig_expr
        gen_exprs[gen_num] = gen_expr
        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"
        results.append(res)

    with open(args.r, "w", encoding="utf-8") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        # all_aed is non-empty here: both input lists were checked above.
        print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)
 |