import sympy as sp
import numpy as np
import warnings
from sympy.abc import x
import sys
import json
from tqdm import tqdm

from remend.tools.parser import parse_prefix_to_sympy

# Silence runtime warnings (overflow, invalid value, ...) raised while sampling
# candidate expressions; non-finite samples are filtered explicitly in
# do_eval_match instead.
warnings.simplefilter("ignore")


def percent(a, n):
    """Format ``a`` as a percentage of ``n`` with one decimal place.

    Returns ``"n/a"`` when ``n`` is 0 rather than raising ZeroDivisionError.
    """
    if n == 0:
        return "n/a"
    return f"{a/n*100:0.1f}%"


def do_eval_match(orig_expr, gen_expr, *, rel_tol=1e-5, min_points=5):
    """Numerically check whether two expressions in ``x`` agree.

    Both expressions are compiled with ``sympy.lambdify`` over the module-level
    symbol ``x`` and evaluated at x = 0.20, 0.21, ..., 0.99. Sample points where
    either value is NaN or +/-inf are skipped. At every remaining point the
    generated value must lie within ``rel_tol`` of the original value, relative
    to the original's magnitude. A single disagreement rejects the pair.

    Args:
        orig_expr: Reference sympy expression in ``x``.
        gen_expr: Candidate sympy expression in ``x``.
        rel_tol: Maximum allowed ``|o - g| / |o|``. If the original value is
            exactly 0, the generated value must also be exactly 0.
        min_points: Minimum number of finite, agreeing samples required.

    Returns:
        True if at least ``min_points`` finite samples were compared and all of
        them agreed. False otherwise, including when compiling or evaluating
        either expression raises.
    """
    try:
        orig_fn = sp.lambdify(x, orig_expr)
        gen_fn = sp.lambdify(x, gen_expr)
        count = 0
        for v in np.arange(0.2, 1, 0.01):
            o = orig_fn(v)
            g = gen_fn(v)
            # Skip points where either side is undefined or overflows. -inf
            # must be excluded as well as +inf/NaN: otherwise the difference
            # becomes NaN, which never exceeds the tolerance and would pass.
            if not (np.isfinite(o) and np.isfinite(g)):
                continue
            # Multiply instead of dividing by o so that o == 0 neither raises
            # ZeroDivisionError nor yields a NaN that compares as a match.
            if abs(o - g) > rel_tol * abs(o):
                return False
            count += 1
    except Exception:
        # lambdify/evaluation can fail in many ways (unsupported functions,
        # type errors, overflow); any such failure counts as "no match".
        # Catching Exception (not a bare except) lets KeyboardInterrupt through.
        return False
    return count >= min_points


def _read_generations(gen_path, info_path):
    """Pair each hypothesis line of the generations file with its info record.

    Lines starting with ``H`` are split on tabs: the id is parsed from the first
    field after its two-character prefix, and the tokens are the third field
    split on spaces. For each such line the next line of the info file is
    consumed and JSON-decoded. Entries whose info has an empty ``"eqn"`` are
    dropped.

    Returns:
        A list of ``(id, tokens, info)`` tuples.
    """
    gens = []
    with open(gen_path, "r") as genf, open(info_path) as infof:
        for line in tqdm(genf, desc="Reading file"):
            # startswith (not line[0]) so blank lines don't raise IndexError.
            if not line.startswith("H"):
                continue
            comps = line.strip().split("\t")
            num = int(comps[0][2:])
            tokens = comps[2].split(" ")
            # NOTE(review): assumes the k-th hypothesis line corresponds to the
            # k-th info line; ``num`` is never checked against the info record.
            # Confirm the generations file is written in info-file order.
            info = json.loads(next(infof).strip())
            if info["eqn"] == "":
                continue
            gens.append((num, tokens, info))
    return gens


def _evaluate(gens):
    """Parse each generated token sequence and compare it to its reference.

    Returns:
        ``(results, parsed, matched)``. ``results`` holds one dict per entry
        with keys ``id``, ``parsed``, ``matched``, ``orig``, ``gen``.
        ``parsed`` and ``matched`` list the ids that parsed and that matched.
    """
    results, parsed, matched = [], [], []
    for n, toks, info in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        # Appended now and filled in below; the dict is mutated in place.
        results.append(res)

        # Hypotheses containing a "<>" token are counted as not parsed.
        if "<>" in toks:
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception:
            # Not parsed.
            continue
        res["parsed"] = True
        parsed.append(n)

        # Replace each constant symbol "k<name>" with its recorded value.
        const = info["constants"]
        gen_expr = gen_expr.subs(
            [(sp.Symbol("k" + name), value) for name, value in const.items()]
        )
        # SECURITY(review): sympy.parse_expr evaluates its input with eval();
        # the info file must come from a trusted source.
        orig_expr = sp.parse_expr(info["eqn"], local_dict={"x0": x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if do_eval_match(orig_expr, gen_expr):
            res["matched"] = True
            matched.append(n)
    return results, parsed, matched


def _write_report(path, results, total, parsed, matched):
    """Write per-entry results followed by summary counts to ``path``."""
    with open(path, "w") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        print("Total", total, file=resf)
        print("Parsed", len(parsed), percent(len(parsed), total), file=resf)
        print("Matched", len(matched), percent(len(matched), total), file=resf)


def main(argv=None):
    """Command-line entry point: read, evaluate, and report on generations."""
    import argparse

    parser = argparse.ArgumentParser(description="Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-i", required=True, help="Info file")
    parser.add_argument("-r", required=True, help="Results file")
    args = parser.parse_args(argv)

    gens = _read_generations(args.g, args.i)
    results, parsed, matched = _evaluate(gens)
    _write_report(args.r, results, len(gens), parsed, matched)


if __name__ == "__main__":
    main()