# REMEND / remend/parser.py
# Author: udiboy1209 (commit 7145fd6: "Add REMEND python module")
# Source page metadata: raw · history · blame · 14.3 kB
import sympy as sp
import networkx as nx
import itertools as it
import sys
from .util import DecodeError, sympy_expr_ok
def _sympy_unary(fname):
    """Return a one-argument callable that applies ``sp.<fname>`` to its input.

    The sympy attribute is looked up at call time, exactly like a plain
    ``lambda a: sp.<fname>(a)`` would do.
    """
    return lambda a: getattr(sp, fname)(a)


# Operator vocabulary: token -> (callable that builds the sympy node, arity).
# Tokens currently left out of the vocabulary: inv, pow2..pow5, abs, sign,
# sinh, cosh, tanh.
OPERATORS = {
    # Binary arithmetic
    'add': (lambda a, b: a + b, 2),
    'sub': (lambda a, b: a - b, 2),
    'mul': (lambda a, b: a * b, 2),
    'div': (lambda a, b: a / b, 2),
    'pow': (lambda a, b: a ** b, 2),
    # Unary elementary / trigonometric / inverse-trigonometric functions.
    # Each token is also the name of the corresponding sympy function.
    **{name: (_sympy_unary(name), 1) for name in (
        'sqrt', 'exp', 'ln',
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
    )},
}
# Constant tokens: the two transcendental constants plus the single decimal
# digits '0'..'9', which map to their plain Python int values.
CONSTANTS = {
    'E': sp.E,
    'pi': sp.pi,
    **{str(digit): digit for digit in range(10)},
}
# Variable tokens, each mapped to a sympy Symbol of the same name:
# x, x0, x1, c0..c10 and k0..k3. constant_fold names its placeholders
# k0, k1, ..., of which only k0..k3 are recognised here.
# (y and z are currently disabled.)
VARIABLES = {
    name: sp.Symbol(name)
    for name in (
        'x', 'x0', 'x1',
        *(f'c{i}' for i in range(11)),
        *(f'k{i}' for i in range(4)),
    )
}
# Maps a sympy expression head (``expr.func``) to the operator token used in
# prefix output; looked up by sympy_func_to_op.
FUNC_TO_OP = {
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.log: 'ln',  # sympy spells the natural logarithm ``log``
    # sp.sqrt(x) evaluates to a Pow, so nodes normally hit the Pow entry.
    sp.sqrt: 'sqrt',
    sp.exp: 'exp',
    # NOTE(review): 'abs' has no entry in OPERATORS, so token streams that
    # contain it are rejected by the prefix parsers.
    sp.Abs: 'abs',
    # Trigonometric functions and their inverses: token == sympy name.
    **{getattr(sp, name): name for name in (
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
    )},
    # Hyperbolic functions (sinh, cosh, tanh) are currently disabled.
}
def sympy_func_to_op(f):
    """Return the operator token for the sympy function/class *f*.

    Args:
        f: a sympy expression head, e.g. ``expr.func``.

    Returns:
        The token string registered for *f* in FUNC_TO_OP.

    Raises:
        DecodeError: if *f* has no entry in FUNC_TO_OP.
    """
    try:
        return FUNC_TO_OP[f]
    except KeyError:
        raise DecodeError(f"Op not found {f}") from None
def isint(s):
    """Return True when ``int(s)`` succeeds, i.e. *s* spells an integer.

    Only ValueError is treated as "not an integer"; any other exception
    raised by ``int`` propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    return True
def _fold_number(marker, sign_pos, pending, cast):
    """Combine a number marker and its buffered characters into a signed value.

    *pending* holds the character tokens in reverse order (as collected by
    reverse_iter_prefix), *sign_pos* is the index of the sign character in
    *marker*, and *cast* is ``int`` or ``float``.
    """
    sign = marker[sign_pos:]
    if sign not in ("+", "-"):
        raise DecodeError(f"Invalid number marker {marker}")
    if not pending:
        raise DecodeError(f"No digits after {marker}")
    text = "".join(reversed(pending))
    try:
        value = cast(text)
    except ValueError:
        raise DecodeError(f"Invalid number {text!r} after {marker}") from None
    return value if sign == "+" else -value


def reverse_iter_prefix(prefs):
    """Yield the tokens of *prefs* in reverse order, folding number literals.

    Numbers are encoded as a marker token followed by one token per
    character, e.g. ``["INT-", "1", "2"]`` for -12 or
    ``["FLOAT+", "1", ".", "5"]`` for 1.5. Walking backwards, character
    tokens (anything ``isint`` accepts, plus ``e``, ``+``, ``-`` and ``.``)
    are buffered until their marker is reached. A single ``int``/``float`` is
    then yielded in place of the whole group. Every other token is yielded
    unchanged.

    Raises:
        DecodeError: if a marker has no characters after it or an invalid
            sign, the characters do not form a valid number, or number
            characters are not immediately preceded by a marker.
    """
    # Character tokens seen so far, in reverse order. Using append (not +=)
    # keeps multi-character tokens such as "12" intact.
    pending = []
    for tok in reversed(prefs):
        if isint(tok) or tok in ("e", "+", "-", "."):
            pending.append(tok)
            continue
        if tok.startswith("INT"):
            yield _fold_number(tok, 3, pending, int)
        elif tok.startswith("FLOAT"):
            yield _fold_number(tok, 5, pending, float)
        elif pending:
            # Digits must sit directly after their marker; anything else means
            # the stream is malformed and the number would be misattributed.
            raise DecodeError(f"Number characters without a marker before {tok}")
        else:
            yield tok
        pending = []
    if pending:
        raise DecodeError(f"Dangling number characters at the start of {prefs}")
def _stack_eval(tokens, prefs):
    """Evaluate folded *tokens* on an operand stack and return one sympy expression.

    Shared by parse_prefix_to_sympy and parse_postfix_to_sympy.
    Operators: an operator pops ``arity`` operands. The first value popped
    becomes the first argument of its OPERATORS callable, and the result is
    pushed back.
    Literals: integer and float literals are pushed as ``sp.Integer`` /
    ``sp.Float``, so arithmetic between two literals stays exact inside
    sympy. For example, ``div INT+ 1 INT+ 2`` is ``1/2`` rather than
    ``0.5``, and division by a zero literal no longer raises a Python
    ``ZeroDivisionError``.
    Other tokens: bracket tokens are ignored. *prefs* is the raw token list
    and is used only in error messages.

    Raises:
        DecodeError: on an unknown token, an operator without enough
            operands, a stream that does not reduce to exactly one value, or
            a result rejected by ``sympy_expr_ok``.
    """
    stack = []
    for val in tokens:
        if val in OPERATORS:
            spop, numops = OPERATORS[val]
            if len(stack) < numops:
                raise DecodeError(f"Not enough operands for {val}: {prefs}")
            operands = [stack.pop() for _ in range(numops)]
            stack.append(spop(*operands))
        elif val in CONSTANTS:
            stack.append(CONSTANTS[val])
        elif val in VARIABLES:
            stack.append(VARIABLES[val])
        elif isinstance(val, int):
            stack.append(sp.Integer(val))
        elif isinstance(val, float):
            stack.append(sp.Float(val))
        elif val in ("(", ")"):
            # Simply ignore brackets
            continue
        else:
            raise DecodeError(f"{val} invalid")
    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr


def parse_prefix_to_sympy(prefs):
    """Parse prefix-notation tokens into a sympy expression.

    Example: ``["sub", "x", "INT+", "1"]`` -> ``x - 1``. Tokens are consumed
    right-to-left, so for a binary operator the first operand popped is its
    left-hand side.

    Raises:
        DecodeError: if the tokens do not form a valid expression.
    """
    return _stack_eval(reverse_iter_prefix(prefs), prefs)


def parse_postfix_to_sympy(prefs):
    """Parse postfix-ordered tokens into a sympy expression.

    Number literals use the same ``marker digits...`` form as prefix input
    (the marker comes first), because folding is done by
    ``reverse_iter_prefix``. The folded tokens are then evaluated
    left-to-right with the same pop order as the prefix parser: the first
    value popped is the left operand.

    NOTE(review): with that pop order, ``INT+ 1 x sub`` parses to ``x - 1``.
    Conventional postfix (``x INT+ 1 sub``) would parse to ``1 - x``.
    Confirm which token order the callers produce.

    Raises:
        DecodeError: if the tokens do not form a valid expression.
    """
    postfix = reversed(list(reverse_iter_prefix(prefs)))
    return _stack_eval(postfix, prefs)
def parse_prefix_to_tree(prefs):
    """Build a labelled expression tree (networkx DiGraph) from prefix tokens.

    Each folded token from ``reverse_iter_prefix`` becomes a node whose
    ``label`` is the token value. Edges point from an operator to its
    operands. For the order-sensitive operators ``pow``, ``sub`` and ``div``,
    the operator instead points to two synthetic nodes labelled ``"lhs"`` and
    ``"rhs"``, which in turn point to the left and right operand.

    Returns:
        (tree, root): the DiGraph and the id of its root node.

    Raises:
        DecodeError: on an unknown token, an operator without enough
            operands, or a stream that does not reduce to a single root.
    """
    tree = nx.DiGraph()
    stack = []
    # Synthetic lhs/rhs ids start at len(prefs). Token node ids come from
    # enumerate() over the folded stream, which yields at most len(prefs)
    # items, so the two id ranges never collide.
    newidx = len(prefs)
    for nidx, val in enumerate(reverse_iter_prefix(prefs)):
        tree.add_node(nidx, label=val)
        if val in OPERATORS:
            _, numops = OPERATORS[val]
            if len(stack) < numops:
                raise DecodeError(f"Not enough operands for {val}: {prefs}")
            children = [stack.pop() for _ in range(numops)]
            if val in {"pow", "sub", "div"}:
                # Ordered children: first popped operand is the left-hand side.
                tree.add_node(newidx, label="lhs")
                tree.add_node(newidx + 1, label="rhs")
                tree.add_edge(nidx, newidx)
                tree.add_edge(nidx, newidx + 1)
                tree.add_edge(newidx, children[0])
                tree.add_edge(newidx + 1, children[1])
                newidx += 2
            else:
                for c in children:
                    tree.add_edge(nidx, c)
        elif val in CONSTANTS or val in VARIABLES or isinstance(val, (int, float)):
            # Leaf: integer and float literals are both valid operands.
            pass
        else:
            raise DecodeError(f"Val {val} invalid")
        stack.append(nidx)
    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    return tree, stack.pop()  # Root node
def sympy_to_dag(expression, csuf=""):
    """Convert a sympy expression into a networkx DAG that shares equal subexpressions.

    Node ids are assigned in post-order (children before parents) starting
    from 0, and structurally equal subexpressions are stored once. Every
    node carries ``"expr"`` (the sympy object) and ``"label"``:

    * leaves: integers as ``"<n>.0<csuf>"``, other rationals as
      ``"<p>.0<csuf>/<q>.0<csuf>"``, floats as ``"<float><csuf>"``, and
      anything else (symbols, E, pi, ...) via ``str``;
    * inner nodes: the operator token from ``sympy_func_to_op``.

    Raises:
        DecodeError: (from sympy_func_to_op) if an inner node's function has
            no token.
    """
    dag = nx.DiGraph()
    node_ids = {}
    counter = it.count()

    def visit(sub):
        # Children first, reusing the id of any subexpression already placed.
        kid_ids = [node_ids[arg] if arg in node_ids else visit(arg)
                   for arg in sub.args]
        my_id = next(counter)
        dag.add_node(my_id, expr=sub)
        node_ids[sub] = my_id
        dag.add_edges_from((my_id, kid) for kid in kid_ids)
        return my_id

    def leaf_label(e):
        # Integer is a Rational subclass, so it has to be tested first.
        if isinstance(e, sp.Integer):
            return f"{e}.0{csuf}"
        if isinstance(e, sp.Rational):
            return f"{e.p}.0{csuf}/{e.q}.0{csuf}"
        if isinstance(e, sp.Float):
            return f"{float(e)}{csuf}"
        return str(e)

    visit(expression)
    for nid, attrs in dag.nodes(data=True):
        if dag.out_degree(nid) == 0:
            attrs["label"] = leaf_label(attrs["expr"])
        else:
            attrs["label"] = sympy_func_to_op(attrs["expr"].func)
    return dag
def sympy_to_prefix(expr):
    """Serialize a sympy expression into a list of prefix (Polish-notation) tokens.

    Encoding rules, as implemented below:

    * Integers: a sign marker followed by one token per decimal digit, e.g.
      ``-12`` -> ``["INT-", "1", "2"]``.
    * Rationals: non-integer rationals become ``div <p> <q>``.
    * Symbols, ``E`` and ``pi``: emitted via ``str``.
    * Mul: factors of the form ``b**-1`` (exponent exactly -1) are collected
      into a denominator and emitted as ``div <numerator> <denominator>``.
      The numerator is ``INT+ 1`` when every factor is such a reciprocal.
    * Add: two-factor terms ``-1*t`` / ``t*-1`` are collected and emitted as
      ``sub <sum of other terms> <sum of t's>``. The left side is ``INT+ 0``
      when every term is negated.
    * n-ary nodes: emitted as right-nested binary chains, e.g. ``a*b*c`` ->
      ``mul a mul b c``.
    * Floats: accepted only when ``sp.nsimplify(node, tolerance=1e-7)``
      yields an integer or a rational with denominator <= 16.
    * Anything else: mapped to a token via ``sympy_func_to_op``.

    Raises:
        DecodeError: for floats that do not simplify as above, and (via
            sympy_func_to_op) for functions without a token.
    """
    trav = []
    def _pre(node):
        # Append the prefix tokens for ``node`` to ``trav`` (recursive pre-order walk).
        nonlocal trav
        if isinstance(node, sp.Rational):
            # Also catches sp.Integer (a Rational subclass); integers take the
            # q == 1 path and are re-dispatched as plain ints.
            if node.q != 1:
                trav.append("div")
                _pre(node.p)
                _pre(node.q)
            else:
                _pre(node.p)
        elif isinstance(node, sp.Integer) or isinstance(node, int):
            # Sign marker, then one token per decimal digit.
            v = int(node)
            if v >= 0:
                trav.append("INT+")
                trav.extend(list(str(v)))
            else:
                trav.append("INT-")
                trav.extend(list(str(-v)))
        elif isinstance(node, sp.Symbol):
            trav.append(str(node))
        elif isinstance(node, sp.Mul):
            # Split factors into ordinary factors and reciprocals (b**-1).
            mulargs = []
            divargs = []
            children = node.args
            for child in children:
                if isinstance(child, sp.Pow) and \
                        isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
                    divargs.append(child.args[0])
                else:
                    mulargs.append(child)
            if len(divargs) > 0:
                trav.append("div")
                if len(mulargs) == 0:
                    # Pure reciprocal: the numerator is the literal 1.
                    trav.append("INT+")
                    trav.append("1")
            # Insert numerator
            for i, child in enumerate(mulargs):
                # One "mul" per adjacent pair gives a right-nested chain.
                if i < len(mulargs) - 1:
                    trav.append("mul")
                _pre(child)
            # Insert denominator
            for i, child in enumerate(divargs):
                if i < len(divargs) - 1:
                    trav.append("mul")
                _pre(child)
        elif isinstance(node, sp.Add):
            # Split terms into ordinary terms and negated terms (-1*t or t*-1).
            addargs = []
            subargs = []
            children = node.args
            for child in children:
                if isinstance(child, sp.Mul) and len(child.args) == 2 and \
                        isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
                    subargs.append(child.args[0])
                elif isinstance(child, sp.Mul) and len(child.args) == 2 and \
                        isinstance(child.args[0], sp.Integer) and child.args[0] == -1:
                    subargs.append(child.args[1])
                else:
                    addargs.append(child)
            if len(subargs) > 0:
                trav.append("sub")
                if len(addargs) == 0:
                    # Every term is negated: the minuend is the literal 0.
                    trav.append("INT+")
                    trav.append("0")
            # Insert the ordinary terms (minuend when subtracting)
            for i, child in enumerate(addargs):
                if i < len(addargs) - 1:
                    trav.append("add")
                _pre(child)
            # Insert the negated terms, summed (subtrahend)
            for i, child in enumerate(subargs):
                if i < len(subargs) - 1:
                    trav.append("add")
                _pre(child)
        elif isinstance(node, sp.Float):
            # Floats are only encodable when they simplify to an integer or a
            # rational with a small denominator; otherwise generation fails.
            rep = sp.nsimplify(node, tolerance=1e-7)
            if isinstance(rep, sp.Integer):
                _pre(rep)
            elif isinstance(rep, sp.Rational) and rep.q <= 16:
                _pre(rep)
            else:
                raise DecodeError(f"Float {node} encountered while generating")
                # trav.append(str(node))
        elif node == sp.E or node == sp.pi:
            # Transcendental constants
            trav.append(str(node))
        else:
            op = sympy_func_to_op(node.func)
            children = node.args
            for i, child in enumerate(children):
                # Insert op repeatedly to maintain binary tree: once for a
                # unary function (i == 0), otherwise once per adjacent pair.
                if i == 0 or i < len(children) - 1:
                    trav.append(op)
                _pre(child)
    _pre(expr)
    return trav
def constant_fold(expr):
    """Replace constant subexpressions of *expr* by exact values or placeholder symbols.

    *expr* is walked breadth-first. Each maximal subexpression that contains
    no symbols is evaluated numerically and then handled as follows:

    * Exact values: if ``sp.nsimplify`` (tolerance 1e-7, allowing E and pi)
      turns it into an integer, a rational with denominator <= 16, E or pi,
      it is replaced by that value.
    * Known placeholders: otherwise, if its float value v satisfies
      v, 1/v, -v or -1/v ≈ an existing placeholder's value (absolute
      difference < 1e-7), it is written as k, 1/k, -k or -1/k respectively.
      The first matching placeholder, in creation order, wins.
    * New placeholders: otherwise a new placeholder symbol ``k0``, ``k1``, ...
      is introduced.

    Returns:
        (folded_expr, constmap), where constmap maps placeholder names to
        the float values they stand for.
    """
    from collections import deque

    def isconst(e):
        return not any(c.is_symbol for c in e.atoms())

    queue = deque([expr])
    cidx = 0
    subsmap = {}
    constmap = {}
    while queue:
        curr_expr = queue.popleft()
        if not (isinstance(curr_expr, sp.Number) or isconst(curr_expr)):
            # Not constant as a whole: look for constant pieces further down.
            queue.extend(curr_expr.args)
            continue
        const_expr = curr_expr.evalf()
        rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=1e-7)
        if (isinstance(rep, sp.Integer)
                or (isinstance(rep, sp.Rational) and rep.q <= 16)
                or rep == sp.E or rep == sp.pi):
            subsmap[curr_expr] = rep
            continue
        val = float(const_expr)
        for name, cval in constmap.items():
            sym = sp.Symbol(name)
            if abs(val - cval) < 1e-7:
                subsmap[curr_expr] = sym
            elif abs(1/val - cval) < 1e-7:
                subsmap[curr_expr] = 1/sym
            elif abs(-val - cval) < 1e-7:
                subsmap[curr_expr] = -sym
            elif abs(-1/val - cval) < 1e-7:
                subsmap[curr_expr] = -1/sym
            else:
                continue
            # Stop at the first match so a later placeholder that only
            # matches within tolerance cannot overwrite an exact one.
            break
        else:
            name = f"k{cidx}"
            subsmap[curr_expr] = sp.Symbol(name)
            constmap[name] = val
            cidx += 1
    return expr.subs(subsmap), constmap
# Manual smoke test; runs only when this module is executed directly.
if __name__ == "__main__":
    sample = "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 add x mul INT- 1 pow x INT+ 5".split(" ")
    simplified = sp.simplify(parse_prefix_to_sympy(sample))
    print(simplified)
    print(constant_fold(simplified))
    # Other manual checks, kept for reference:
    # prefs = "mul x mul pow cos INT+ 4 INT- 3 pow ln INT+ 3 INT- 6".split(" ")
    # exp = parse_prefix_to_sympy(prefs)
    # print(exp)
    # dag = sympy_to_dag(exp)
    # exp = sp.parse_expr("(((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - (-((x0) + (x0)))) / (-((x0) + (x0)))) * ((-((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))))) * ((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0)))))))", evaluate=False)
    # # print(sympy_to_prefix(exp))
    # simp = sp.simplify(exp)
    # pre = sympy_to_prefix(simp)
    # print(pre)
    # repars = parse_prefix_to_sympy(pre)
    # print(simp)
    # print(repars)