File size: 4,840 Bytes
7145fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import sympy as sp
import sys
import re
from tqdm import tqdm
from Levenshtein import distance
import networkx as nx
from networkx import graph_edit_distance

from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint

def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Returns ``"n/a"`` when ``n`` is 0, because the ratio is undefined there
    (e.g. a report over an empty input set).

    >>> percent(1, 4)
    '25.0%'
    """
    if n == 0:
        return "n/a"
    return f"{a / n * 100:0.1f}%"

def do_simplify_match(orig_expr, gen_expr):
    """Return True if two SymPy expressions are equal after simplification.

    The simplified forms are first compared structurally. ``sp.simplify``
    does not produce a canonical form, so mathematically equal inputs can
    simplify to different trees. As a fallback, the difference of the
    simplified forms is simplified again and compared with 0.

    Args:
        orig_expr: reference expression as returned by the parser.
        gen_expr: generated expression as returned by the parser.

    Returns:
        bool: True if the expressions are judged equal, else False.
    """
    orig_simp = sp.simplify(orig_expr)
    gen_simp = sp.simplify(gen_expr)
    if orig_simp == gen_simp:
        return True
    try:
        # Subtraction is not defined for non-arithmetic SymPy objects
        # (e.g. relationals, booleans) and raises TypeError; those are
        # treated as a mismatch since the structural check already failed.
        return bool(sp.simplify(orig_simp - gen_simp) == 0)
    except TypeError:
        return False

_CONST_TOKEN_RE = re.compile(r"c[0-9]+")
_VAR_TOKEN_RE = re.compile(r"x[0-9]+")

def do_structure_match(orig_toks, gen_toks):
    """Return True if two token sequences share the same structure.

    The sequences match when they have equal length and every aligned token
    pair is either identical, or belongs to the same placeholder class:
    both constants (``c<digits>``), both variables (``x<digits>``), both
    integers according to ``isint``, or both starting with ``"INT"``. The
    concrete constant/variable indices and integer values are ignored.

    Args:
        orig_toks: reference token list.
        gen_toks: generated token list.

    Returns:
        bool: True on a structural match, else False.
    """
    def _isconst(t):
        # fullmatch: "c12" is a constant, but "c1x" is not (re.match would
        # accept any token that merely starts with c<digit>).
        return _CONST_TOKEN_RE.fullmatch(t) is not None

    def _isvar(t):
        return _VAR_TOKEN_RE.fullmatch(t) is not None

    if len(orig_toks) != len(gen_toks):
        return False
    for orig, gen in zip(orig_toks, gen_toks):
        # Identical tokens always match; test that first as it is cheapest.
        if orig == gen:
            continue
        if (_isconst(orig) and _isconst(gen)) \
                or (_isvar(orig) and _isvar(gen)) \
                or (isint(orig) and isint(gen)) \
                or (orig.startswith("INT") and gen.startswith("INT")):
            continue
        # Mismatched
        return False
    return True

def _read_generate_output(path):
    """Collect target and hypothesis token lists from a generation log.

    Expects tab-separated lines of the form ``T-<id><TAB><tokens>`` and
    ``H-<id><TAB><score><TAB><tokens>`` (presumably fairseq-generate output),
    with space-separated tokens. All other lines are ignored.

    Args:
        path: path of the generated-expressions file.

    Returns:
        ``(targets, hypotheses)``: ``targets`` is a list of ``(id, tokens)``
        tuples in file order; ``hypotheses`` maps id -> tokens of the FIRST
        ``H`` line seen for that id.
    """
    targets = []
    hypotheses = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading file"):
            # Strip only the line ending: a full strip() would also remove
            # the trailing tab of an empty hypothesis and drop its field.
            comps = line.rstrip("\r\n").split("\t")
            if line.startswith("T-"):
                num = int(comps[0][2:])
                targets.append((num, comps[1].strip().split(" ")))
            elif line.startswith("H-"):
                num = int(comps[0][2:])
                # Several H lines can share an id (n-best output); keep the
                # first one so each target is paired with exactly one.
                hypotheses.setdefault(num, comps[2].strip().split(" "))
    return targets, hypotheses

def main():
    """Command-line entry point: score hypotheses against their targets.

    Reads the ``-g`` file, compares every target with its hypothesis, and
    writes one ``<id> <aed> <parsed> <matched>`` line per target to ``-r``,
    followed by summary counts.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    parse = parse_postfix_to_sympy if args.postfix else parse_prefix_to_sympy

    targets, hypotheses = _read_generate_output(args.g)
    N = len(targets)
    if N == 0:
        sys.exit(f"No target (T-) lines found in {args.g}")
    missing = [num for num, _ in targets if num not in hypotheses]
    if missing:
        sys.exit(f"{len(missing)} target(s) have no hypothesis (H-) line, "
                 f"e.g. id {missing[0]}")

    gen_errors = []
    exact_match = []
    structure_match = []
    simplify_match = []
    all_aed = []
    results = []

    for num, orig_toks in tqdm(targets, desc="Parsing expressions"):
        gen_toks = hypotheses[num]
        # Token-level edit distance normalized by the combined length.
        total_len = len(orig_toks) + len(gen_toks)
        aed = distance(orig_toks, gen_toks) / total_len if total_len else 0.0
        all_aed.append(aed)
        # Appended now and filled in below (the dict is mutated in place).
        res = {"id": num, "aed": aed, "matched": False, "parsed": False}
        results.append(res)

        if aed == 0:
            exact_match.append(num)
            structure_match.append(num)
            res["parsed"] = True
            res["matched"] = "Exact"
            continue

        if "<<unk>>" in orig_toks:
            # The target contains an unknown token and is not parsed. Checked
            # before structure matching so the summary counts agree with the
            # per-line "not matched" record.
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(num)
            res["matched"] = "Structure"

        # Targets are expected to parse; a failure here aborts the run.
        orig_expr = parse(orig_toks)
        try:
            gen_expr = parse(gen_toks)
        except Exception:  # any hypothesis parse failure is a parse error
            gen_errors.append(num)
            continue
        res["parsed"] = True

        if orig_expr == gen_expr:
            exact_match.append(num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(num)
            res["matched"] = "Simplify"

    with open(args.r, "w", encoding="utf-8") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        # N > 0 is guaranteed above, so all_aed is non-empty.
        print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)

if __name__ == "__main__":
    main()