# REMEND / remend/implementation.py
# Commit 7145fd6 by udiboy1209: "Add REMEND python module"
import sympy as sp
from sympy.codegen import ast
import itertools as it
import networkx as nx
from .parser import OPERATORS, sympy_to_dag
from .util import DecodeError
def isnum(s):
    """Return True when ``float(s)`` succeeds, False otherwise.

    Args:
        s: Value to probe, typically a string token.

    Returns:
        bool: True if ``s`` can be converted by ``float``. Note that this
        includes spellings such as ``"nan"`` and ``"inf"``.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: a string that is not a number.
        # TypeError: an object float() cannot convert at all (None, list, ...).
        return False
    return True
class Implementor:
    """Render one expression as a complete C or Fortran program.

    Every generator returns source text that defines ``myfunc(x)`` together
    with a small main program which reads ``x`` from standard input and
    prints ``myfunc(x)``. Two strategies exist per language: walking the DAG
    built by ``sympy_to_dag`` and emitting one temporary per operator node,
    or running ``sympy.cse`` and printing the pieces with SymPy's printers.
    """

    # Maps the ``impl`` selector accepted by ``implement`` to the name of the
    # method that generates the corresponding program.
    _IMPLEMENTATIONS = {
        "dag_c": "dag_to_c_impl",
        "cse_c": "sympy_cse_c_impl",
        "dag_fortran": "dag_to_fortran_impl",
        "cse_fortran": "sympy_cse_fortran_impl",
    }

    def __init__(self, expr, constants=None, dtype="double"):
        """Store the expression and derive the per-language type spellings.

        Args:
            expr: Expression handed to ``sympy_to_dag`` and ``sympy.cse``.
            constants: Optional mapping of constant name to value; every
                entry is emitted as a declaration ahead of the computation.
                When omitted, a new empty dict is created for this instance.
            dtype: C type name used for ``x``, the temporaries and the
                result. ``"double"`` selects double precision; any other
                value is treated as single precision for the scanf/printf
                conversion and for the Fortran type.
        """
        self.expr = expr
        # A fresh dict per instance: a ``{}`` default argument would be one
        # object shared by every Implementor created without constants.
        self.constants = {} if constants is None else constants
        self.cdtype = dtype
        is_double = dtype == "double"
        # scanf/printf conversion specifier (used as "%lf" / "%f").
        self.cpf = "lf" if is_double else "f"
        # Fortran type keyword matching ``dtype``.
        self.fdtype = "double precision" if is_double else "real"

    def implement(self, impl):
        """Generate source code with the strategy named by ``impl``.

        Args:
            impl: One of ``"dag_c"``, ``"cse_c"``, ``"dag_fortran"`` or
                ``"cse_fortran"``.

        Returns:
            str: The generated program text.

        Raises:
            ValueError: If ``impl`` is not one of the names above.
        """
        try:
            method_name = self._IMPLEMENTATIONS[impl]
        except (KeyError, TypeError):
            # TypeError covers unhashable selectors; both mean "unknown".
            known = ", ".join(sorted(self._IMPLEMENTATIONS))
            raise ValueError(
                f"unknown implementation {impl!r}; expected one of: {known}"
            ) from None
        return getattr(self, method_name)()

    def op_c_impl(self, f, children):
        """Return the C expression applying operator ``f`` to ``children``.

        Args:
            f: Operator label. ``"add"``, ``"mul"``, ``"pow"`` and ``"ln"``
                are handled explicitly; any other label must be a key of
                ``OPERATORS`` whose entry has ``1`` at index 1 (presumably
                the operator's arity -- confirm against the parser module).
            children: Operand strings (variable names or literals).

        Returns:
            str: C expression text without a trailing semicolon.

        Raises:
            DecodeError: If ``f`` is not handled.
        """
        # Functions get an ``f`` suffix (powf, logf, ...) unless the type is
        # double.
        suffix = "" if self.cdtype == "double" else "f"
        if f == "add":
            return " + ".join(children)
        if f == "mul":
            return " * ".join(children)
        if f == "pow":
            assert len(children) == 2
            return f"pow{suffix}({children[0]}, {children[1]})"
        if f == "ln":
            assert len(children) == 1
            # The "ln" label is spelled log/logf in C.
            return f"log{suffix}({children[0]})"
        if f in OPERATORS and OPERATORS[f][1] == 1:
            assert len(children) == 1
            # The operator label is used verbatim as the C function name.
            return f"{f}{suffix}({children[0]})"
        raise DecodeError(f"C impl: operation {f} not handled")

    def op_f_impl(self, f, children):
        """Return the Fortran expression applying operator ``f`` to ``children``.

        Accepts the same operator labels as ``op_c_impl``. Operands of
        ``add``, ``mul`` and ``pow`` are each wrapped in parentheses.

        Args:
            f: Operator label.
            children: Operand strings (variable names or literals).

        Returns:
            str: Fortran expression text.

        Raises:
            DecodeError: If ``f`` is not handled.
        """
        if f == "add":
            return "(" + ")+(".join(children) + ")"
        if f == "mul":
            return "(" + ")*(".join(children) + ")"
        if f == "pow":
            assert len(children) == 2
            return f"({children[0]})**({children[1]})"
        if f == "ln":
            assert len(children) == 1
            return f"log({children[0]})"
        if f in OPERATORS and OPERATORS[f][1] == 1:
            assert len(children) == 1
            return f"{f}({children[0]})"
        raise DecodeError(f"F impl: operation {f} not handled")

    def full_c_code(self, body):
        """Wrap ``body`` into a complete C program.

        Args:
            body: Statements forming the body of ``myfunc``; they are
                expected to end with a ``return``.

        Returns:
            str: C source containing the includes, ``myfunc`` and a ``main``
            that scans ``x`` and prints ``myfunc(x)``.
        """
        pre = (
            "#include <stdio.h>\n#include <math.h>\n"
            f"{self.cdtype} myfunc({self.cdtype} x) {{"
        )
        post = (
            f'}}\nint main() {{ {self.cdtype} x; scanf("%{self.cpf}", &x); '
            f'printf("%{self.cpf}", myfunc(x)); }}'
        )
        return f"{pre}\n{body}\n{post}"

    def full_f_code(self, body):
        """Wrap ``body`` into a complete Fortran program.

        Args:
            body: Declarations and statements for ``myfunc``; they are
                expected to assign the result variable ``y``.

        Returns:
            str: Fortran source containing function ``myfunc`` (which
            declares ``x``, ``y``, ``E`` and ``pi``) and a main program that
            reads ``x`` and prints ``myfunc(x)``.
        """
        pre = (
            "function myfunc(x) result(y)\nimplicit none\n"
            f"{self.fdtype}, intent(in) :: x\n{self.fdtype} :: y, E, pi\n"
        )
        post = (
            "end function myfunc\nprogram main\nimplicit none\n"
            f"{self.fdtype} :: x\n{self.fdtype} :: myfunc\n"
            "read(*, *) x\nprint *, \"y is:\", myfunc(x)\nend program main"
        )
        return f"{pre}\n{body}\n{post}"

    def dag_to_c_impl(self):
        """Generate a C program by walking the expression DAG.

        Nodes are visited in reverse topological order so that a node's
        children already have a ``"var"`` attribute when the node itself is
        visited. Leaves keep their label as their name; every other node is
        assigned to a fresh temporary ``t0``, ``t1``, ...

        Returns:
            str: Complete C source (see ``full_c_code``).

        Raises:
            DecodeError: If an operator is not handled by ``op_c_impl`` or
                the DAG contains no nodes.
        """
        is_float = self.cdtype == "float"
        dag = sympy_to_dag(self.expr, csuf="F" if is_float else "")
        # Leaf label -> (float constant declaration, name used otherwise).
        special = {
            "pi": ("const float pi = 3.14159265F;\n", "M_PI"),
            "E": ("const float E = 2.71828183F;\n", "M_E"),
        }
        declared = set()
        lines = [
            f"{self.cdtype} {name} = {value};\n"
            for name, value in self.constants.items()
        ]
        varidx = it.count()
        last_temp = None
        last_leaf = None
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if not children:
                if label in special:
                    declaration, macro = special[label]
                    if is_float:
                        # Declare the float constant once; later occurrences
                        # keep referring to it rather than switching to the
                        # M_PI / M_E names.
                        if label not in declared:
                            lines.append(declaration)
                            declared.add(label)
                    else:
                        label = macro
                dag.nodes[node]["var"] = label
                last_leaf = label
                continue
            varname = f"t{next(varidx)}"
            cexpr = self.op_c_impl(label, children)
            dag.nodes[node]["var"] = varname
            lines.append(f"{self.cdtype} {varname} = {cexpr};\n")
            last_temp = varname
        # The result is the last temporary emitted; an expression that is a
        # single leaf has no temporaries, so return that leaf directly.
        retname = last_temp if last_temp is not None else last_leaf
        if retname is None:
            raise DecodeError("C impl: expression produced no DAG nodes")
        lines.append(f"return {retname};\n")
        return self.full_c_code("".join(lines))

    def dag_to_fortran_impl(self):
        """Generate a Fortran program by walking the expression DAG.

        Uses the same traversal as ``dag_to_c_impl``. Declarations are
        collected separately and placed ahead of the statements.

        Returns:
            str: Complete Fortran source (see ``full_f_code``).

        Raises:
            DecodeError: If an operator is not handled by ``op_f_impl`` or
                the DAG contains no nodes.
        """
        # Literal suffix: "d0" marks double precision constants.
        csuf = "" if self.fdtype == "real" else "d0"
        dag = sympy_to_dag(self.expr, csuf=csuf)
        decls = []
        # NOTE(review): E and pi are written with single-precision literals
        # even when fdtype is double precision -- confirm this is intended.
        stmts = ["parameter E = 2.71828183\nparameter pi = 3.14159265\n"]
        for name, value in self.constants.items():
            decls.append(f"{self.fdtype} :: {name}\n")
            stmts.append(f"parameter {name} = {value}{csuf}\n")
        varidx = it.count()
        last_temp = None
        last_leaf = None
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if not children:
                dag.nodes[node]["var"] = label
                last_leaf = label
                continue
            varname = f"t{next(varidx)}"
            fexpr = self.op_f_impl(label, children)
            dag.nodes[node]["var"] = varname
            stmts.append(f"{varname} = {fexpr}\n")
            decls.append(f"{self.fdtype} :: {varname}\n")
            last_temp = varname
        # The result is the last temporary emitted; an expression that is a
        # single leaf has no temporaries, so assign that leaf directly.
        retname = last_temp if last_temp is not None else last_leaf
        if retname is None:
            raise DecodeError("F impl: expression produced no DAG nodes")
        stmts.append(f"y = {retname};\n")
        return self.full_f_code("".join(decls) + "\n" + "".join(stmts))

    def sympy_cse_c_impl(self):
        """Generate a C program from ``sympy.cse`` output and ``sympy.ccode``.

        Each extracted subexpression and the final expression (assigned to
        ``y``) is declared and assigned on its own line.

        Returns:
            str: Complete C source (see ``full_c_code``).
        """
        if self.cdtype == "float":
            # Printer settings for single precision. NOTE(review): the empty
            # math_macros presumably stops the printer from substituting
            # math.h macro names -- confirm against the SymPy docs.
            printer_args = {
                "type_aliases": {ast.real: ast.float32},
                "math_macros": {},
            }
        else:
            printer_args = {}
        lines = [
            f"{self.cdtype} {name} = {value};\n"
            for name, value in self.constants.items()
        ]
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            code = sp.ccode(vxpr, assign_to=vname.name, **printer_args)
            lines.append(f"{self.cdtype} {vname.name}; {code};\n")
        assert len(xpr) == 1
        code = sp.ccode(xpr[0], assign_to="y", **printer_args)
        lines.append(f"{self.cdtype} y; {code}; return y;\n")
        return self.full_c_code("".join(lines))

    def sympy_cse_fortran_impl(self):
        """Generate a Fortran program from ``sympy.cse`` output and ``sympy.fcode``.

        Returns:
            str: Complete Fortran source (see ``full_f_code``).
        """
        # Literal suffix: "d0" marks double precision constants.
        csuf = "" if self.fdtype == "real" else "d0"
        decls = []
        stmts = []
        for name, value in self.constants.items():
            decls.append(f"{self.fdtype} :: {name}\n")
            stmts.append(f"parameter {name} = {value}{csuf}\n")
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            decls.append(f"{self.fdtype} :: {vname.name}\n")
            stmts.append(
                sp.fcode(vxpr, assign_to=vname.name, standard=95, source_format="free")
                + "\n"
            )
        assert len(xpr) == 1
        stmts.append(
            sp.fcode(xpr[0], assign_to="y", standard=95, source_format="free") + "\n"
        )
        fstr = "".join(decls) + "\n" + "".join(stmts)
        if self.fdtype == "real":
            # HACK: drop every "d0" so the printed literals become single
            # precision. This is a plain text replacement over the whole
            # body, so it also alters any identifier containing "d0".
            fstr = fstr.replace("d0", "")
        return self.full_f_code(fstr)
# For testing only
if __name__ == "__main__":
    from .parser import parse_prefix_to_sympy, sympy_to_dag

    # Token list in the prefix notation consumed by parse_prefix_to_sympy.
    tokens = "add mul div INT+ 1 INT+ 5 x mul div INT+ 1 INT+ 5 mul x tan pow x INT+ 2".split(" ")
    demo = Implementor(parse_prefix_to_sympy(tokens), dtype="float")

    # Print each generated program under its heading, one strategy at a time.
    for heading, generate in (
        ("DAG C:", demo.dag_to_c_impl),
        ("DAG Fortran:", demo.dag_to_fortran_impl),
        ("CSE C:", demo.sympy_cse_c_impl),
        ("CSE Fortran:", demo.sympy_cse_fortran_impl),
    ):
        print(heading)
        print(generate())