# REMEND / real_world_dataset / eval_dataset.py
# Commit e78b7eb ("Add real world dataset") by udiboy1209; file size 3.14 kB.
import sympy as sp
import numpy as np
import warnings
from sympy.abc import x
import sys
import json
from tqdm import tqdm
from remend.tools.parser import parse_prefix_to_sympy
warnings.simplefilter("ignore")
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Args:
        a: Numerator (e.g. number of matching samples).
        n: Denominator (e.g. total number of samples).

    Returns:
        A string such as ``"25.0%"``. When ``n`` is zero the ratio is
        undefined, so ``"0.0%"`` is returned instead of raising
        ``ZeroDivisionError``; this keeps summary reporting alive on an
        empty dataset.
    """
    if n == 0:
        return "0.0%"
    return f"{a / n * 100:0.1f}%"
def do_eval_match(orig_expr, gen_expr):
    """Numerically check whether two expressions in ``x`` agree.

    Both expressions are compiled with ``sp.lambdify`` and evaluated at the
    sample points ``np.arange(0.2, 1, 0.01)``. Points where either expression
    evaluates to NaN or to +/-infinity are skipped; every remaining point
    must agree to a relative tolerance of 1e-5 (measured against the value
    of ``orig_expr``).

    Args:
        orig_expr: Reference expression in the symbol ``x``.
        gen_expr: Candidate expression in the symbol ``x``.

    Returns:
        True if at least 5 sample points were compared and all of them
        agreed; False if any compared point disagreed, if fewer than 5
        points were comparable, or if compiling/evaluating either
        expression raised an exception.
    """
    rel_tol = 1e-5
    min_points = 5
    count = 0
    try:
        origl = sp.lambdify(x, orig_expr)
        genl = sp.lambdify(x, gen_expr)
        for v in np.arange(0.2, 1, 0.01):
            o = origl(v)
            g = genl(v)
            # Skip points where either side is undefined (NaN) or unbounded.
            # Both +inf and -inf are excluded: a -inf reference would
            # otherwise produce a NaN relative error that silently passes.
            if not np.isfinite(o) or not np.isfinite(g):
                continue
            # Multiplied-out form of abs((o - g) / o) > rel_tol, so a
            # reference value of exactly zero cannot raise
            # ZeroDivisionError for plain Python scalars.
            if abs(o - g) > rel_tol * abs(o):
                return False
            count += 1
    except Exception:
        # Expressions that cannot be compiled or evaluated count as a
        # mismatch. KeyboardInterrupt / SystemExit are not swallowed.
        return False
    return count >= min_points
if __name__ == "__main__":
    import argparse

    # Command-line interface: three required file paths.
    cli = argparse.ArgumentParser("Check generated expressions")
    cli.add_argument("-g", required=True, help="Generated expressions file")
    cli.add_argument("-i", required=True, help="Info file")
    cli.add_argument("-r", required=True, help="Results file")
    args = cli.parse_args()

    # Stage 1: collect (sample id, hypothesis tokens, info record) triples.
    # Only lines starting with 'H' are used; each such line consumes the next
    # JSON line of the info file.
    # NOTE(review): presumably "H-<id>\t<score>\t<tokens>" hypothesis lines,
    # listed in the same order as the info file - confirm.
    gens = []
    with open(args.g, 'r') as genf, open(args.i) as infof:
        for line in tqdm(genf, desc="Reading file"):
            if line[0] != 'H':
                continue
            comps = line.strip().split("\t")
            num = int(comps[0][2:])
            tokens = comps[2].split(" ")
            info = json.loads(next(infof).strip())
            # Samples without a reference equation are left out entirely.
            if info["eqn"] == "":
                continue
            gens.append((num, tokens, info))

    # Stage 2: parse each hypothesis and compare it with its reference.
    parsed = []
    matched = []
    results = []
    for n, toks, info in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        # Register the record once; it is filled in place as the sample
        # passes each stage below.
        results.append(res)

        # Hypotheses containing an unknown token are reported as not parsed.
        if "<<unk>>" in toks:
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            # Not parsed
            continue
        res["parsed"] = True
        parsed.append(n)

        # Replace the placeholder symbols k<name> with the constant values
        # recorded for this sample.
        const = info["constants"]
        substitutions = [(sp.Symbol("k" + key), const[key]) for key in const]
        gen_expr = gen_expr.subs(substitutions)
        # The reference equation uses "x0" for the variable; map it to x.
        orig_expr = sp.parse_expr(info["eqn"], local_dict={"x0": x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            matched.append(n)

    # Stage 3: write one row per sample, a blank line, then summary counts.
    row_template = "{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n"
    with open(args.r, "w") as resf:
        resf.writelines(row_template.format(**row) for row in results)
        resf.write("\n")
        total = len(gens)
        n_parsed = len(parsed)
        n_matched = len(matched)
        print("Total", total, file=resf)
        print("Parsed", n_parsed, percent(n_parsed, total), file=resf)
        print("Matched", n_matched, percent(n_matched, total), file=resf)