"""Conversion between prefix/postfix token sequences, SymPy expressions and graphs.

Numbers inside token sequences are written as a sign marker followed by one
token per character, e.g. ``INT+ 1 2`` for 12 or ``INT- 3`` for -3
(``FLOAT+``/``FLOAT-`` likewise, where ``.``, ``e``, ``+`` and ``-`` may also
appear among the characters).
"""
import itertools as it
import sys
from collections import deque

import networkx as nx
import sympy as sp

from .util import DecodeError, sympy_expr_ok

# Operator token -> (callable building the expression, arity).
OPERATORS = {
    # Elementary functions
    'add': (lambda a, b: a + b, 2),
    'sub': (lambda a, b: a - b, 2),
    'mul': (lambda a, b: a * b, 2),
    'div': (lambda a, b: a / b, 2),
    'pow': (lambda a, b: a ** b, 2),
    # 'inv': (lambda a: 1/a, 1),
    # 'pow2': (lambda a: a**2, 1),
    # 'pow3': (lambda a: a**3, 1),
    # 'pow4': (lambda a: a**4, 1),
    # 'pow5': (lambda a: a**5, 1),
    'sqrt': (lambda a: sp.sqrt(a), 1),
    'exp': (lambda a: sp.exp(a), 1),
    'ln': (lambda a: sp.ln(a), 1),
    # 'abs': (lambda a: sp.abs(a), 1),
    # 'sign': (lambda a: sp.sign(a), 1),
    # Trigonometric Functions
    'sin': (lambda a: sp.sin(a), 1),
    'cos': (lambda a: sp.cos(a), 1),
    'tan': (lambda a: sp.tan(a), 1),
    'cot': (lambda a: sp.cot(a), 1),
    'sec': (lambda a: sp.sec(a), 1),
    'csc': (lambda a: sp.csc(a), 1),
    # Trigonometric Inverses
    'asin': (lambda a: sp.asin(a), 1),
    'acos': (lambda a: sp.acos(a), 1),
    'atan': (lambda a: sp.atan(a), 1),
    'acot': (lambda a: sp.acot(a), 1),
    'asec': (lambda a: sp.asec(a), 1),
    'acsc': (lambda a: sp.acsc(a), 1),
    # Hyperbolic
    # 'sinh': (lambda a: sp.sinh(a), 1),
    # 'cosh': (lambda a: sp.cosh(a), 1),
    # 'tanh': (lambda a: sp.tanh(a), 1),
}

# Constant token -> value.
CONSTANTS = {
    'E': sp.E,
    'pi': sp.pi,
    '0': 0,
    '1': 1,
    '2': 2,
    '3': 3,
    '4': 4,
    '5': 5,
    '6': 6,
    '7': 7,
    '8': 8,
    '9': 9,
}

# Variable token -> SymPy symbol.
VARIABLES = {
    'x': sp.Symbol('x'),
    'x0': sp.Symbol('x0'),
    'x1': sp.Symbol('x1'),
    'c0': sp.Symbol('c0'),
    'c1': sp.Symbol('c1'),
    'c2': sp.Symbol('c2'),
    'c3': sp.Symbol('c3'),
    'c4': sp.Symbol('c4'),
    'c5': sp.Symbol('c5'),
    'c6': sp.Symbol('c6'),
    'c7': sp.Symbol('c7'),
    'c8': sp.Symbol('c8'),
    'c9': sp.Symbol('c9'),
    'c10': sp.Symbol('c10'),
    'k0': sp.Symbol('k0'),
    'k1': sp.Symbol('k1'),
    'k2': sp.Symbol('k2'),
    'k3': sp.Symbol('k3'),
    # 'y': sp.Symbol('y'),
    # 'z': sp.Symbol('z')
}

# SymPy expression head (``expr.func``) -> operator token.
FUNC_TO_OP = {
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.log: 'ln',
    sp.sqrt: 'sqrt',
    sp.exp: 'exp',
    # NOTE: 'abs' has no entry in OPERATORS (it is commented out there), so
    # prefix sequences produced for Abs(...) are rejected by the parsers.
    sp.Abs: 'abs',
    # 'abs': (lambda a: sp.abs(a), 1),
    # 'sign': (lambda a: sp.sign(a), 1),
    # Trigonometric Functions
    sp.sin: 'sin',
    sp.cos: 'cos',
    sp.tan: 'tan',
    sp.cot: 'cot',
    sp.sec: 'sec',
    sp.csc: 'csc',
    # Trigonometric Inverses
    sp.asin: 'asin',
    sp.acos: 'acos',
    sp.atan: 'atan',
    sp.acot: 'acot',
    sp.asec: 'asec',
    sp.acsc: 'acsc',
    # Hyperbolic
    # sp.cosh: 'cosh',
    # sp.sinh: 'sinh',
    # sp.tanh: 'tanh'
}

# Non-digit characters that may appear inside an INT/FLOAT literal.
_NUMBER_CHARS = frozenset({"e", "+", "-", "."})

# Binary operators whose operand order matters; the tree encodes it explicitly.
_ORDERED_OPS = frozenset({"pow", "sub", "div"})


def sympy_func_to_op(f):
    """Return the operator token for the SymPy head ``f`` (e.g. ``sp.Add`` -> ``"add"``).

    Raises:
        DecodeError: if ``f`` has no entry in ``FUNC_TO_OP``.
    """
    try:
        return FUNC_TO_OP[f]
    except KeyError:
        raise DecodeError(f"Op not found {f}") from None


def isint(s):
    """Return True if ``int(s)`` succeeds, False if it raises ``ValueError``."""
    try:
        int(s)
    except ValueError:
        return False
    return True


def _parse_number(pending, marker, sign_pos, cast):
    """Turn reversed number characters plus their sign marker into a number.

    ``pending`` holds the characters in reverse order and is cleared so the
    caller can start collecting the next literal. Any sign character other
    than ``+`` negates the value.

    Raises:
        DecodeError: if the marker has no sign character or the characters
            do not form a valid ``cast`` literal.
    """
    if len(marker) <= sign_pos:
        raise DecodeError(f"Number marker {marker} has no sign")
    text = "".join(reversed(pending))
    pending.clear()
    try:
        value = cast(text)
    except ValueError as err:
        raise DecodeError(f"Invalid number {marker} {text!r}") from err
    return value if marker[sign_pos] == "+" else -value


def reverse_iter_prefix(prefs):
    """Yield the tokens of ``prefs`` from last to first, folding number literals.

    Walking backwards, the characters of a literal are collected first and
    the preceding ``INT``/``FLOAT`` marker then emits them as a single
    ``int``/``float``. Every other token is yielded unchanged.

    Raises:
        DecodeError: on a malformed literal, or on number characters that are
            not claimed by a marker (previously they were silently merged into
            a neighbouring number or dropped).
    """
    pending = []  # number characters seen so far, in reverse order
    for tok in reversed(prefs):
        if isint(tok) or tok in _NUMBER_CHARS:
            # append (not ``+=``) keeps a multi-character token such as "12"
            # whole, instead of splitting it and later reversing it to "21".
            pending.append(tok)
        elif tok[:3] == "INT":
            yield _parse_number(pending, tok, 3, int)
        elif tok[:5] == "FLOAT":
            yield _parse_number(pending, tok, 5, float)
        else:
            if pending:
                raise DecodeError(
                    f"Number without INT/FLOAT marker before {tok}: {prefs}")
            yield tok
    if pending:
        raise DecodeError(f"Number without INT/FLOAT marker: {prefs}")


def _stack_to_sympy(tokens, prefs):
    """Evaluate an operand-first token stream with a stack machine.

    Each operator pops its arity's worth of values; the value popped first
    becomes the first argument. ``prefs`` is used only in error messages.

    Raises:
        DecodeError: on unknown tokens, missing operands, leftover operands,
            arithmetic errors on plain Python numbers, or a result rejected by
            ``sympy_expr_ok``.
    """
    stack = []
    for val in tokens:
        if val in OPERATORS:
            spop, numops = OPERATORS[val]
            if len(stack) < numops:
                raise DecodeError(f"Missing operands for {val}: {prefs}")
            operands = [stack.pop() for _ in range(numops)]
            try:
                stack.append(spop(*operands))
            except (ZeroDivisionError, OverflowError) as err:
                # Plain int/float operands (e.g. ``div INT+ 1 INT+ 0``) use
                # Python arithmetic, which raises instead of returning zoo.
                raise DecodeError(f"Cannot evaluate {val}{tuple(operands)}") from err
        elif val in CONSTANTS:
            stack.append(CONSTANTS[val])
        elif val in VARIABLES:
            stack.append(VARIABLES[val])
        elif isinstance(val, (int, float)):
            stack.append(val)
        elif val in ("(", ")"):
            continue  # Simply ignore brackets
        else:
            raise DecodeError(f"{val} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr


def parse_prefix_to_sympy(prefs):
    """Decode a prefix token sequence into a SymPy expression (or plain number).

    Raises:
        DecodeError: if the sequence is not a valid expression.
    """
    return _stack_to_sympy(reverse_iter_prefix(prefs), prefs)


def parse_postfix_to_sympy(prefs):
    """Decode a postfix token sequence into a SymPy expression (or plain number).

    Number literals are still decoded with their marker first; the folded
    tokens are then evaluated left to right.

    NOTE(review): the most recently pushed operand becomes the first
    argument, so ``a b sub`` decodes to ``b - a``; confirm this matches the
    encoder's postfix convention.

    Raises:
        DecodeError: if the sequence is not a valid expression.
    """
    tokens = reversed(list(reverse_iter_prefix(prefs)))
    return _stack_to_sympy(tokens, prefs)


def parse_prefix_to_tree(prefs):
    """Build a networkx expression tree from prefix tokens.

    Token nodes are numbered by their position in the reversed, number-folded
    token stream and labelled with the decoded token. Operators in
    ``_ORDERED_OPS`` get two synthetic children labelled ``"lhs"``/``"rhs"``
    so operand order is explicit; other operators link directly to operands.
    Only ``int`` number literals are accepted.

    Returns:
        ``(tree, root)`` where ``root`` is the id of the root node.

    Raises:
        DecodeError: on an unknown token, missing operands, or leftovers.
    """
    tree = nx.DiGraph()
    stack = []
    # Token ids are always < len(prefs), so synthetic ids starting here are unique.
    newidx = len(prefs)
    for nidx, val in enumerate(reverse_iter_prefix(prefs)):
        tree.add_node(nidx, label=val)
        if val in OPERATORS:
            _, numops = OPERATORS[val]
            if len(stack) < numops:
                raise DecodeError(f"Missing operands for {val}: {prefs}")
            childs = [stack.pop() for _ in range(numops)]
            if val in _ORDERED_OPS:
                lhs, rhs = newidx, newidx + 1
                tree.add_node(lhs, label="lhs")
                tree.add_node(rhs, label="rhs")
                tree.add_edge(nidx, lhs)
                tree.add_edge(nidx, rhs)
                tree.add_edge(lhs, childs[0])
                tree.add_edge(rhs, childs[1])
                newidx += 2
            else:
                for child in childs:
                    tree.add_edge(nidx, child)
        elif not (val in CONSTANTS or val in VARIABLES or isinstance(val, int)):
            raise DecodeError(f"Val {val} invalid")
        stack.append(nidx)

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    return tree, stack.pop()  # Root node


def _dag_leaf_label(e, csuf):
    """Label for a DAG leaf: numbers as float-like strings with ``csuf``, else ``str(e)``."""
    if isinstance(e, sp.Integer):
        return f"{e}.0{csuf}"
    if isinstance(e, sp.Rational):
        return f"{e.p}.0{csuf}/{e.q}.0{csuf}"
    if isinstance(e, sp.Float):
        return f"{float(e)}{csuf}"
    return str(e)


def sympy_to_dag(expression, csuf=""):
    """Convert a SymPy expression into a DAG that shares equal sub-expressions.

    Equal sub-expressions (by SymPy ``==``/hash) map to a single node. Ids are
    assigned in post-order, so every child id is smaller than its parent's and
    the root has the largest id. Each node carries ``expr`` (the SymPy
    sub-expression) and ``label``: the operator token for internal nodes, or
    a leaf label from ``_dag_leaf_label``.

    Raises:
        DecodeError: if an internal node's head is not in ``FUNC_TO_OP``.
    """
    dag = nx.DiGraph()
    seen = {}
    ids = it.count()

    def _visit(node):
        child_ids = [seen[c] if c in seen else _visit(c) for c in node.args]
        nid = next(ids)
        dag.add_node(nid, expr=node)
        seen[node] = nid
        dag.add_edges_from((nid, cid) for cid in child_ids)
        return nid

    _visit(expression)

    for nid, data in dag.nodes(data=True):
        if dag.out_degree(nid) == 0:
            data["label"] = _dag_leaf_label(data["expr"], csuf)
        else:
            data["label"] = sympy_func_to_op(data["expr"].func)

    return dag


def sympy_to_prefix(expr):
    """Serialise a SymPy expression into the prefix token format read by the parsers.

    n-ary heads become right-nested binary chains (``op a op b c``). Inside a
    ``Mul``, factors of the form ``x**-1`` go to a ``div`` denominator. Inside
    an ``Add``, terms of the form ``-1*x`` go to a ``sub`` subtrahend. Floats
    are accepted only when they are, within 1e-7, an integer or a rational with
    denominator <= 16.

    Raises:
        DecodeError: on other floats or heads missing from ``FUNC_TO_OP``.
    """
    trav = []

    def _chain(op, items):
        # Emit items joined by ``op`` as a right-nested chain.
        last = len(items) - 1
        for i, item in enumerate(items):
            if i < last:
                trav.append(op)
            _pre(item)

    def _is_neg_one(e):
        return isinstance(e, sp.Integer) and e == -1

    def _pre(node):
        if isinstance(node, sp.Rational):
            # Also catches sp.Integer (q == 1), which is re-emitted as a Python int.
            if node.q != 1:
                trav.append("div")
                _pre(node.p)
                _pre(node.q)
            else:
                _pre(node.p)
        elif isinstance(node, (sp.Integer, int)):
            v = int(node)
            trav.append("INT+" if v >= 0 else "INT-")
            trav.extend(str(abs(v)))
        elif isinstance(node, sp.Symbol):
            trav.append(str(node))
        elif isinstance(node, sp.Mul):
            mulargs, divargs = [], []
            for child in node.args:
                if isinstance(child, sp.Pow) and _is_neg_one(child.args[1]):
                    divargs.append(child.args[0])
                else:
                    mulargs.append(child)
            if divargs:
                trav.append("div")
                if not mulargs:
                    trav.extend(["INT+", "1"])
            _chain("mul", mulargs)  # numerator
            _chain("mul", divargs)  # denominator
        elif isinstance(node, sp.Add):
            addargs, subargs = [], []
            for child in node.args:
                pair = isinstance(child, sp.Mul) and len(child.args) == 2
                if pair and _is_neg_one(child.args[1]):
                    subargs.append(child.args[0])
                elif pair and _is_neg_one(child.args[0]):
                    subargs.append(child.args[1])
                else:
                    addargs.append(child)
            if subargs:
                trav.append("sub")
                if not addargs:
                    trav.extend(["INT+", "0"])
            _chain("add", addargs)  # minuend
            _chain("add", subargs)  # subtrahend
        elif isinstance(node, sp.Float):
            rep = sp.nsimplify(node, tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                _pre(rep)
            else:
                raise DecodeError(f"Float {node} encountered while generating")
        elif node == sp.E or node == sp.pi:
            # Transcendental constants
            trav.append(str(node))
        else:
            op = sympy_func_to_op(node.func)
            args = node.args
            if len(args) == 1:
                trav.append(op)
                _pre(args[0])
            else:
                _chain(op, args)

    _pre(expr)
    return trav


def _match_known_constant(val, constmap, tol=1e-7):
    """Express ``val`` through an existing constant, or return None.

    Tries, per constant ``c`` in insertion order, ``c``, ``1/c``, ``-c`` and
    ``-1/c`` within ``tol``; the first match wins.
    """
    inv = 1 / val if val != 0 else None
    for name, cval in constmap.items():
        sym = sp.Symbol(name)
        if abs(val - cval) < tol:
            return sym
        if inv is not None and abs(inv - cval) < tol:
            return 1 / sym
        if abs(-val - cval) < tol:
            return -sym
        if inv is not None and abs(-inv - cval) < tol:
            return -1 / sym
    return None


def constant_fold(expr):
    """Replace constant sub-expressions with exact values or named constants.

    The expression is walked breadth-first; the first constant subtree met on
    each path is numerically evaluated. If it is (within 1e-7) an integer, a
    rational with denominator <= 16, or exactly E or pi, that exact value is
    substituted. Otherwise it becomes a symbol ``k0``, ``k1``, ... (reusing an
    existing one when the value equals it, its reciprocal, or their negations).

    NOTE(review): names restart at ``k0`` on every call, so an input that
    already contains ``k`` symbols may collide with the new ones.

    Returns:
        ``(folded_expr, constmap)`` with ``constmap`` mapping each new
        constant name to its float value.
    """
    if isinstance(expr, (int, float)):
        # The parsers can return plain Python numbers, which have no .atoms().
        expr = sp.sympify(expr)

    queue = deque([expr])
    cidx = 0
    subsmap = {}
    constmap = {}

    def isconst(e):
        return not any(atom.is_symbol for atom in e.atoms())

    while queue:
        curr_expr = queue.popleft()
        if not (isinstance(curr_expr, sp.Number) or isconst(curr_expr)):
            queue.extend(curr_expr.args)
            continue

        const_expr = curr_expr.evalf()
        rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=1e-7)
        if isinstance(rep, sp.Integer) or \
                (isinstance(rep, sp.Rational) and rep.q <= 16) or \
                rep == sp.E or rep == sp.pi:
            subsmap[curr_expr] = rep
            continue

        val = float(const_expr)
        match = _match_known_constant(val, constmap)
        if match is None:
            name = f"k{cidx}"
            constmap[name] = val
            cidx += 1
            match = sp.Symbol(name)
        subsmap[curr_expr] = match

    return expr.subs(subsmap), constmap


# For testing only
if __name__ == "__main__":
    prefs = "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 add x mul INT- 1 pow x INT+ 5".split(" ")
    exp = parse_prefix_to_sympy(prefs)
    exp = sp.simplify(exp)
    print(exp)
    print(constant_fold(exp))

    # prefs = "mul x mul pow cos INT+ 4 INT- 3 pow ln INT+ 3 INT- 6".split(" ")
    # exp = parse_prefix_to_sympy(prefs)
    # print(exp)
    # dag = sympy_to_dag(exp)
    # exp = sp.parse_expr("(((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - (-((x0) + (x0)))) / (-((x0) + (x0)))) * ((-((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))))) * ((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0)))))))", evaluate=False)
    # # print(sympy_to_prefix(exp))
    # simp = sp.simplify(exp)
    # pre = sympy_to_prefix(simp)
    # print(pre)
    # repars = parse_prefix_to_sympy(pre)
    # print(simp)
    # print(repars)