|
|
import sympy as sp |
|
|
import networkx as nx |
|
|
import itertools as it |
|
|
import sys |
|
|
|
|
|
from .util import DecodeError, sympy_expr_ok |
|
|
|
|
|
# Operator token -> (callable, arity). The parsers look a token up here,
# pop `arity` operands off their stack and call `callable(*operands)`.
OPERATORS = {
    'add': (lambda lhs, rhs: lhs + rhs, 2),
    'sub': (lambda lhs, rhs: lhs - rhs, 2),
    'mul': (lambda lhs, rhs: lhs * rhs, 2),
    'div': (lambda lhs, rhs: lhs / rhs, 2),
    'pow': (lambda lhs, rhs: lhs ** rhs, 2),
}

# Unary operators: every token is also the name of the sympy callable it
# applies, so reference that callable directly (insertion order preserved).
OPERATORS.update(
    (name, (getattr(sp, name), 1))
    for name in (
        'sqrt', 'exp', 'ln',
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
    )
)
|
|
|
|
|
# Constant token -> value: the two named sympy constants followed by the
# single-digit tokens '0'..'9', which map to plain Python ints.
CONSTANTS = {
    'E': sp.E,
    'pi': sp.pi,
    **{str(digit): digit for digit in range(10)},
}
|
|
|
|
|
# Variable token -> sympy Symbol of the same name: x, x0, x1, then c0..c10
# and k0..k3 (constant_fold names its placeholders k0, k1, ...).
VARIABLES = {
    name: sp.Symbol(name)
    for name in (
        'x', 'x0', 'x1',
        *(f'c{idx}' for idx in range(11)),
        *(f'k{idx}' for idx in range(4)),
    )
}
|
|
|
|
|
# sympy expression head (the value of `expr.func`) -> operator token, used
# when serialising expressions (see sympy_func_to_op).
# NOTE(review): 'abs' is emitted for sp.Abs but OPERATORS has no 'abs' entry,
# so the prefix/postfix parsers reject it — confirm this is intended.
# NOTE(review): sympy usually builds sqrt(x) as a Pow, so the sp.sqrt key may
# never match an `expr.func` — verify before relying on it.
FUNC_TO_OP = {
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.log: 'ln',
    sp.sqrt: 'sqrt',
    sp.exp: 'exp',
    sp.Abs: 'abs',
}

# Trigonometric functions and their inverses: token == sympy attribute name.
FUNC_TO_OP.update(
    (getattr(sp, token), token)
    for token in (
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
    )
)
|
|
|
|
|
def sympy_func_to_op(f):
    """Return the operator token registered for the sympy expression head ``f``.

    Args:
        f: an expression head as found in ``expr.func`` (e.g. ``sp.Add``).

    Returns:
        The token string mapped to ``f`` in ``FUNC_TO_OP``.

    Raises:
        DecodeError: if ``f`` has no entry in ``FUNC_TO_OP``.
    """
    try:
        return FUNC_TO_OP[f]
    except KeyError:
        # The original trailing `return str(f)` was unreachable (both branches
        # exit), so an unknown head always ends in this error.
        raise DecodeError(f"Op not found {f}") from None
|
|
|
|
|
def isint(s):
    """Return True if ``int(s)`` succeeds, False otherwise.

    Values that ``int()`` cannot handle at all (e.g. ``None`` or a list)
    raise ``TypeError`` rather than ``ValueError``; they are reported as
    "not an integer" instead of propagating the exception.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        return False
    return True
|
|
|
|
|
def reverse_iter_prefix(prefs):
    """Yield the tokens of ``prefs`` from last to first, folding encoded numbers.

    A number is written as a marker token (``INT...`` or ``FLOAT...``; the
    character right after the prefix is the sign) followed by the tokens that
    spell its magnitude: integer-like tokens and, for floats, ``e``, ``+``,
    ``-`` and ``.``. Scanning right to left, the spelling tokens are buffered
    until the marker is reached; the assembled number is then yielded,
    negated unless the sign character is ``+``. All other tokens are yielded
    unchanged.

    Args:
        prefs: sequence of string tokens (must support ``len`` and indexing).

    Yields:
        ``int``/``float`` for encoded numbers, the raw token otherwise.

    Raises:
        ValueError: if a marker's buffered magnitude is empty or unparsable.
        IndexError: if a marker has no sign character (e.g. bare ``"INT"``).
    """
    pending = []  # magnitude tokens, collected in reverse order
    for token in reversed(prefs):
        if isint(token) or token in ("e", "+", "-", "."):
            # Buffer whole tokens. The old `currnum += token` extended the
            # buffer character-by-character, so a multi-character token such
            # as "12" came out reversed ("21") after the final reversal.
            pending.append(token)
        elif token[:3] == "INT":
            magnitude = int("".join(reversed(pending)))
            yield magnitude if token[3] == "+" else -magnitude
            pending = []
        elif token[:5] == "FLOAT":
            magnitude = float("".join(reversed(pending)))
            yield magnitude if token[5] == "+" else -magnitude
            pending = []
        else:
            yield token
|
|
|
|
|
def parse_prefix_to_sympy(prefs):
    """Decode a prefix (Polish-notation) token sequence into a sympy expression.

    Tokens are consumed right to left via ``reverse_iter_prefix``. Operands
    are pushed on a stack, and each operator pops its arity's worth of values
    and pushes the result. Parenthesis tokens are ignored.

    Args:
        prefs: sequence of string tokens, numbers encoded as described in
            ``reverse_iter_prefix``.

    Returns:
        The decoded expression.

    Raises:
        DecodeError: on an unknown token, an operator with too few operands,
            a division by zero / overflow while combining plain Python
            numbers, a number of results other than one, or when
            ``sympy_expr_ok`` rejects the result.
    """
    stack = []
    for val in reverse_iter_prefix(prefs):
        if val in OPERATORS:
            spop, numops = OPERATORS[val]
            if len(stack) < numops:
                # Previously surfaced as a bare IndexError from list.pop().
                raise DecodeError(f"Not enough operands for {val}: {prefs} || {stack}")
            # Right-to-left scanning leaves the operand written first on top
            # of the stack, so pop order is argument order.
            operands = [stack.pop() for _ in range(numops)]
            try:
                # NOTE(review): INT tokens and digit CONSTANTS are plain Python
                # ints, so `div` on two of them is true division and yields a
                # float (div INT+ 1 INT+ 2 -> 0.5); Rationals emitted by
                # sympy_to_prefix therefore don't decode exactly — confirm.
                stack.append(spop(*operands))
            except (ZeroDivisionError, OverflowError) as err:
                # Plain-number arithmetic (e.g. 1/0, 0**-1, 10.0**400) raises
                # instead of producing zoo/oo; report it as undecodable input.
                raise DecodeError(f"Cannot evaluate {val} on {operands}: {err}") from err
        elif val in CONSTANTS:
            stack.append(CONSTANTS[val])
        elif val in VARIABLES:
            stack.append(VARIABLES[val])
        elif isinstance(val, (int, float)):
            stack.append(val)
        elif val in ("(", ")"):
            # Parentheses carry no information in prefix notation.
            continue
        else:
            raise DecodeError(f"{val} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr
|
|
|
|
|
def parse_postfix_to_sympy(prefs):
    """Decode a postfix token sequence into a sympy expression.

    Numbers are folded with ``reverse_iter_prefix`` (so they keep the
    ``INT±``/``FLOAT±`` marker-first encoding). The folded tokens are then
    processed left to right with an operand stack. Parenthesis tokens are
    ignored.

    NOTE(review): operands are passed in pop order, i.e. most recently pushed
    first, so ``a b sub`` evaluates to ``b - a``. That matches a postfix
    stream whose operands are emitted right to left; confirm it agrees with
    the encoder that produces these sequences.

    Args:
        prefs: sequence of string tokens.

    Returns:
        The decoded expression.

    Raises:
        DecodeError: on an unknown token, an operator with too few operands,
            a division by zero / overflow while combining plain Python
            numbers, a number of results other than one, or when
            ``sympy_expr_ok`` rejects the result.
    """
    stack = []
    postfix = reversed(list(reverse_iter_prefix(prefs)))
    for val in postfix:
        if val in OPERATORS:
            spop, numops = OPERATORS[val]
            if len(stack) < numops:
                # Previously surfaced as a bare IndexError from list.pop().
                raise DecodeError(f"Not enough operands for {val}: {prefs} || {stack}")
            operands = [stack.pop() for _ in range(numops)]
            try:
                stack.append(spop(*operands))
            except (ZeroDivisionError, OverflowError) as err:
                # Plain-number arithmetic (e.g. 1/0) raises instead of giving
                # zoo/oo; report it as undecodable input.
                raise DecodeError(f"Cannot evaluate {val} on {operands}: {err}") from err
        elif val in CONSTANTS:
            stack.append(CONSTANTS[val])
        elif val in VARIABLES:
            stack.append(VARIABLES[val])
        elif isinstance(val, (int, float)):
            stack.append(val)
        elif val in ("(", ")"):
            # Parentheses carry no information in postfix notation.
            continue
        else:
            raise DecodeError(f"{val} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr
|
|
|
|
|
|
|
|
def parse_prefix_to_tree(prefs):
    """Build a ``networkx.DiGraph`` tree from a prefix token sequence.

    Every folded token from ``reverse_iter_prefix`` becomes a node whose id
    is its position in that right-to-left stream and whose ``label`` is the
    token. Operators get edges to their operand nodes. For the order-sensitive
    operators ``pow``, ``sub`` and ``div``, the two operands are attached
    through intermediate ``"lhs"``/``"rhs"`` nodes, so the argument order is
    recorded in the graph.

    Args:
        prefs: sequence of string tokens.

    Returns:
        ``(tree, root_id)``.

    Raises:
        DecodeError: on an unknown token, an operator with too few operands,
            or a number of roots other than one.
    """
    tree = nx.DiGraph()
    stack = []
    # reverse_iter_prefix yields at most len(prefs) items, so helper node
    # ids allocated from here upward never collide with token node ids.
    newidx = len(prefs)
    for nidx, val in enumerate(reverse_iter_prefix(prefs)):
        tree.add_node(nidx, label=val)
        if val in OPERATORS:
            _, numops = OPERATORS[val]
            if len(stack) < numops:
                # Previously surfaced as a bare IndexError from list.pop().
                raise DecodeError(f"Not enough operands for {val}: {prefs} || {stack}")
            # Top of stack = operand written first in prefix order.
            children = [stack.pop() for _ in range(numops)]
            if val in {"pow", "sub", "div"}:
                tree.add_node(newidx, label="lhs")
                tree.add_node(newidx + 1, label="rhs")
                tree.add_edge(nidx, newidx)
                tree.add_edge(nidx, newidx + 1)
                tree.add_edge(newidx, children[0])
                tree.add_edge(newidx + 1, children[1])
                newidx += 2
            else:
                for child in children:
                    tree.add_edge(nidx, child)
        elif val in CONSTANTS or val in VARIABLES or isinstance(val, (int, float)):
            # Leaf. Floats (FLOAT± tokens) are accepted like in the sympy
            # parsers; they were previously rejected as invalid.
            pass
        else:
            raise DecodeError(f"Val {val} invalid")
        stack.append(nidx)

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    return tree, stack.pop()
|
|
|
|
|
def sympy_to_dag(expression, csuf=""):
    """Convert a sympy expression into a DAG that shares equal sub-expressions.

    Node ids are assigned in post-order (every argument gets its id before the
    expression using it). Each node carries ``expr`` (the sympy
    sub-expression) and ``label`` (a string). Sub-expressions that compare
    equal as dict keys are stored once and shared by all parents. Edges point
    from an expression to its arguments and are inserted in argument order.

    Leaf labels: integers -> ``"<n>.0<csuf>"``, other rationals ->
    ``"<p>.0<csuf>/<q>.0<csuf>"``, floats -> ``"<float><csuf>"``, anything else
    -> ``str(expr)``. Inner nodes are labelled with ``sympy_func_to_op``.

    NOTE(review): repeated identical arguments (e.g. Pow(x, x)) collapse to a
    single edge in a DiGraph, so multiplicity is not recoverable from edges.

    Args:
        expression: sympy expression to convert.
        csuf: suffix appended to every numeric literal in the labels.

    Returns:
        The ``nx.DiGraph``.

    Raises:
        DecodeError: (via ``sympy_func_to_op``) for an unsupported head.
    """
    graph = nx.DiGraph()
    node_ids = {}  # sub-expression -> node id
    id_source = it.count()

    def _visit(sub):
        # Resolve (or create) all argument nodes first -> post-order ids.
        arg_ids = []
        for arg in sub.args:
            arg_ids.append(node_ids[arg] if arg in node_ids else _visit(arg))
        sub_id = next(id_source)
        graph.add_node(sub_id, expr=sub)
        node_ids[sub] = sub_id
        graph.add_edges_from((sub_id, arg_id) for arg_id in arg_ids)
        return sub_id

    def _leaf_label(value):
        # sp.Integer is a subclass of sp.Rational, so it has to be tested first.
        if isinstance(value, sp.Integer):
            return f"{value}.0{csuf}"
        if isinstance(value, sp.Rational):
            return f"{value.p}.0{csuf}/{value.q}.0{csuf}"
        if isinstance(value, sp.Float):
            return f"{float(value)}{csuf}"
        return str(value)

    _visit(expression)

    for nid, attrs in graph.nodes(data=True):
        if graph.adj[nid]:
            attrs["label"] = sympy_func_to_op(attrs["expr"].func)
        else:
            attrs["label"] = _leaf_label(attrs["expr"])

    return graph
|
|
|
|
|
def sympy_to_prefix(expr):
    """Serialise a sympy expression into a prefix (Polish-notation) token list.

    Encoding:
      * integers -> ``INT+``/``INT-`` followed by one token per decimal digit;
      * non-integer rationals -> ``div <numerator> <denominator>``;
      * symbols, ``E`` and ``pi`` -> their string form;
      * products -> ``mul`` chain; factors raised to the integer power -1 are
        gathered into one ``div`` denominator (numerator ``INT+ 1`` if no
        other factor remains);
      * sums -> ``add`` chain; two-factor terms with a literal -1 factor are
        gathered into one ``sub`` right-hand side (left-hand side ``INT+ 0``
        if no other term remains);
      * floats -> ``nsimplify`` (tolerance 1e-7), accepted only as an integer
        or a rational with denominator <= 16;
      * any other node -> ``sympy_func_to_op(node.func)`` over its arguments.

    Chains are right-nested: ``op a op b c``.

    Raises:
        DecodeError: for a float that cannot be rationalised as above, or (via
            ``sympy_func_to_op``) for an unsupported expression head.
    """
    tokens = []
    emit = tokens.append

    def _chained(op, items):
        # Yield items in order, emitting `op` before every item but the last.
        # A generator keeps the recursion depth of _walk unchanged.
        last = len(items) - 1
        for pos, item in enumerate(items):
            if pos < last:
                emit(op)
            yield item

    def _negated_operand(term):
        # For a two-factor product with a literal -1 factor, return the
        # other factor; otherwise None.
        if not (isinstance(term, sp.Mul) and len(term.args) == 2):
            return None
        head, tail = term.args
        if isinstance(tail, sp.Integer) and tail == -1:
            return head
        if isinstance(head, sp.Integer) and head == -1:
            return tail
        return None

    def _walk(node):
        if isinstance(node, sp.Rational):
            if node.q == 1:
                _walk(node.p)
            else:
                emit("div")
                _walk(node.p)
                _walk(node.q)
        elif isinstance(node, (sp.Integer, int)):
            value = int(node)
            emit("INT+" if value >= 0 else "INT-")
            tokens.extend(str(abs(value)))
        elif isinstance(node, sp.Symbol):
            emit(str(node))
        elif isinstance(node, sp.Mul):
            numerator, denominator = [], []
            for factor in node.args:
                is_reciprocal = (
                    isinstance(factor, sp.Pow)
                    and isinstance(factor.args[1], sp.Integer)
                    and factor.args[1] == -1
                )
                if is_reciprocal:
                    denominator.append(factor.args[0])
                else:
                    numerator.append(factor)
            if denominator:
                emit("div")
                if not numerator:
                    emit("INT+")
                    emit("1")
            for item in _chained("mul", numerator):
                _walk(item)
            for item in _chained("mul", denominator):
                _walk(item)
        elif isinstance(node, sp.Add):
            positive, negative = [], []
            for term in node.args:
                negated = _negated_operand(term)
                if negated is None:
                    positive.append(term)
                else:
                    negative.append(negated)
            if negative:
                emit("sub")
                if not positive:
                    emit("INT+")
                    emit("0")
            for item in _chained("add", positive):
                _walk(item)
            for item in _chained("add", negative):
                _walk(item)
        elif isinstance(node, sp.Float):
            approx = sp.nsimplify(node, tolerance=1e-7)
            if isinstance(approx, sp.Integer) or (
                isinstance(approx, sp.Rational) and approx.q <= 16
            ):
                _walk(approx)
            else:
                raise DecodeError(f"Float {node} encountered while generating")
        elif node in (sp.E, sp.pi):
            emit(str(node))
        else:
            op = sympy_func_to_op(node.func)
            args = node.args
            last = len(args) - 1
            for pos, arg in enumerate(args):
                # Unary heads get one op token; n-ary heads get n-1.
                if pos == 0 or pos < last:
                    emit(op)
                _walk(arg)

    _walk(expr)
    return tokens
|
|
|
|
|
def constant_fold(expr):
    """Replace constant sub-expressions of ``expr`` by exact values or placeholders.

    The expression tree is walked breadth-first. A sub-expression that is a
    number or contains no symbols is replaced as a whole, and its arguments
    are not visited. It is evaluated numerically and then:

      * replaced by its exact form if ``nsimplify`` (tolerance 1e-7, with
        ``E`` and ``pi`` as candidate constants) yields an integer, a
        rational with denominator <= 16, ``E`` or ``pi``;
      * otherwise replaced by ``k``, ``1/k``, ``-k`` or ``-1/k`` for the
        first existing placeholder ``k`` whose value matches within 1e-7;
      * otherwise replaced by a fresh placeholder ``k<i>``.

    Args:
        expr: sympy expression.

    Returns:
        ``(folded_expr, constmap)`` where ``constmap`` maps each placeholder
        name to the float value it stands for.
    """
    from collections import deque

    tol = 1e-7
    queue = deque([expr])  # deque.popleft is O(1); list.pop(0) was O(n)
    cidx = 0
    subsmap = {}
    constmap = {}

    def isconst(e):
        # Constant when no atom of the sub-expression is a symbol.
        return not any(atom.is_symbol for atom in e.atoms())

    while queue:
        curr_expr = queue.popleft()
        if not (isinstance(curr_expr, sp.Number) or isconst(curr_expr)):
            queue.extend(curr_expr.args)
            continue

        const_expr = curr_expr.evalf()
        rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=tol)
        if isinstance(rep, sp.Integer) or \
                (isinstance(rep, sp.Rational) and rep.q <= 16) or \
                rep == sp.E or rep == sp.pi:
            subsmap[curr_expr] = rep
            continue

        val = float(const_expr)
        replacement = None
        for name, known in constmap.items():
            if abs(val - known) < tol:
                replacement = sp.Symbol(name)
            elif abs(1 / val - known) < tol:
                replacement = 1 / sp.Symbol(name)
            elif abs(-val - known) < tol:
                replacement = -sp.Symbol(name)
            elif abs(-1 / val - known) < tol:
                replacement = -1 / sp.Symbol(name)
            if replacement is not None:
                # First match wins; without this break a later placeholder
                # silently overwrote an earlier match.
                break

        if replacement is None:
            name = f"k{cidx}"
            replacement = sp.Symbol(name)
            constmap[name] = val
            cidx += 1
        subsmap[curr_expr] = replacement

    return expr.subs(subsmap), constmap
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: decode a hand-written prefix sequence, simplify it, then
    # print it before and after constant folding.
    demo_tokens = (
        "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 "
        "add x mul INT- 1 pow x INT+ 5"
    ).split(" ")
    simplified = sp.simplify(parse_prefix_to_sympy(demo_tokens))
    print(simplified)
    print(constant_fold(simplified))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|