|
import sympy as sp |
|
import numpy as np |
|
import warnings |
|
from sympy.abc import x |
|
import sys |
|
import json |
|
from tqdm import tqdm |
|
|
|
from remend.tools.parser import parse_prefix_to_sympy |
|
|
|
# Ignore every warning process-wide. Presumably this silences numpy
# RuntimeWarnings (overflow / invalid value / divide by zero) raised while the
# lambdified expressions are sampled in do_eval_match -- confirm before removing.
warnings.simplefilter("ignore")
|
|
|
def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Args:
        a: the part (e.g. number of parsed or matched expressions).
        n: the whole (e.g. total number of expressions).

    Returns:
        A string such as ``"33.3%"``. When ``n`` is zero (nothing was read),
        ``"0.0%"`` is returned instead of raising ``ZeroDivisionError``.
    """
    if n == 0:
        # Keep the report format stable for an empty input file.
        return f"{0:0.1f}%"
    return f"{a / n * 100:0.1f}%"
|
|
|
def do_eval_match(orig_expr, gen_expr, *, rel_tol=1e-5, min_points=5,
                  samples=None):
    """Numerically check whether two sympy expressions in ``x`` agree.

    Both expressions are lambdified and evaluated at each sample point.
    Points where either value is NaN or infinite (of either sign) are
    skipped. The expressions match when every remaining point satisfies
    ``|orig - gen| <= rel_tol * |orig|`` and at least ``min_points`` points
    were compared.

    Args:
        orig_expr: reference sympy expression in the symbol ``x``.
        gen_expr: candidate sympy expression in the symbol ``x``.
        rel_tol: maximum error relative to the reference value.
        min_points: minimum number of finite sample points required.
        samples: iterable of x values to test; defaults to
            ``np.arange(0.2, 1, 0.01)``.

    Returns:
        True if the expressions agree, otherwise False. This includes the
        case where lambdifying or evaluating either expression raises.
    """
    if samples is None:
        samples = np.arange(0.2, 1, 0.01)
    try:
        origl = sp.lambdify(x, orig_expr)
        genl = sp.lambdify(x, gen_expr)
        count = 0
        for v in samples:
            o = origl(v)
            g = genl(v)
            # Skip points where either expression is undefined (nan) or
            # unbounded. Checking only +inf let -inf through, and
            # (-inf - g) / -inf evaluates to nan, which is never "> tol" and
            # so was counted as a match.
            if not (np.isfinite(o) and np.isfinite(g)):
                continue
            # Same test as abs((o - g) / o) > rel_tol, without dividing by o.
            # At o == 0 it requires g == 0 exactly.
            if abs(o - g) > rel_tol * abs(o):
                return False
            count += 1
    except Exception:
        # Expressions that cannot be lambdified or evaluated never match.
        # Catch Exception rather than using a bare except, so Ctrl-C still
        # aborts the run.
        return False
    return count >= min_points
|
|
|
def _load_generations(gen_path, info_path):
    """Pair each hypothesis line of the generation file with its info record.

    Hypothesis lines start with ``H-<id>``, followed by tab-separated fields
    whose third field holds space-separated tokens. This format is inferred
    from the parsing below and looks like fairseq-generate output. The info
    file holds one JSON object per line. The hypothesis id is used as the
    0-based line index of its record. Entries whose ``"eqn"`` is empty are
    dropped.

    Returns:
        A list of ``(id, tokens, info)`` tuples in generation-file order.

    Raises:
        ValueError: if a hypothesis id has no corresponding info line.
    """
    with open(info_path) as infof:
        # Keep the raw lines; JSON is decoded only for ids actually referenced.
        info_lines = infof.readlines()

    gens = []
    with open(gen_path, "r") as genf:
        for line in tqdm(genf, desc="Reading file"):
            if not line.startswith("H-"):
                continue
            # Strip only the line terminator. A full strip() also removes the
            # tab before an empty hypothesis, which made comps[2] raise
            # IndexError.
            comps = line.rstrip("\r\n").split("\t")
            num = int(comps[0][2:])
            token_field = comps[2] if len(comps) > 2 else ""
            tokens = token_field.strip().split(" ")
            # Look up the record by id instead of reading the info file in
            # lockstep. If the hypotheses are not sorted by id, or some ids
            # are missing, lockstep pairing silently misaligns every
            # following record.
            if num >= len(info_lines):
                raise ValueError(
                    f"hypothesis id {num} has no line in info file {info_path}"
                )
            info = json.loads(info_lines[num].strip())
            if info["eqn"] == "":
                continue
            gens.append((num, tokens, info))
    return gens


def _evaluate(gens):
    """Parse each generated token sequence and compare it to its reference.

    Returns:
        ``(results, n_parsed, n_matched)``. ``results`` holds one dict per
        entry of ``gens`` with keys ``id``, ``parsed``, ``matched``, ``orig``
        and ``gen``.
    """
    results = []
    n_parsed = 0
    n_matched = 0
    for n, toks, info in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        # Record the entry now; the flags below update this same dict.
        results.append(res)
        if "<<unk>>" in toks:
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            continue
        res["parsed"] = True
        n_parsed += 1

        const = info["constants"]
        # Substitute each placeholder symbol "k<name>" with its constant value.
        gen_expr = gen_expr.subs([(sp.Symbol("k" + c), const[c]) for c in const])
        # NOTE(review): sympy.parse_expr evaluates its input with eval(); only
        # run this script on trusted info files.
        orig_expr = sp.parse_expr(info["eqn"], local_dict={"x0": x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            n_matched += 1
    return results, n_parsed, n_matched


def _write_results(path, results, n_total, n_parsed, n_matched):
    """Write one line per result, then a blank line and summary counts."""
    with open(path, "w") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        print("Total", n_total, file=resf)
        print("Parsed", n_parsed, percent(n_parsed, n_total), file=resf)
        print("Matched", n_matched, percent(n_matched, n_total), file=resf)


if __name__ == "__main__":
    import argparse

    # The first positional argument of ArgumentParser is `prog`. Passing the
    # text there replaced the program name in the usage line instead of
    # setting the description.
    parser = argparse.ArgumentParser(description="Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-i", required=True, help="Info file")
    parser.add_argument("-r", required=True, help="Results file")
    args = parser.parse_args()

    gens = _load_generations(args.g, args.i)
    results, n_parsed, n_matched = _evaluate(gens)
    _write_results(args.r, results, len(gens), n_parsed, n_matched)
|
|