|
import sympy as sp |
|
from sympy.codegen import ast |
|
import itertools as it |
|
import networkx as nx |
|
|
|
from .parser import OPERATORS, sympy_to_dag |
|
from .util import DecodeError |
|
|
|
def isnum(s):
    """Return True if ``s`` can be converted with ``float()``, else False.

    Both ``ValueError`` (malformed strings such as ``"abc"``) and
    ``TypeError`` (non-numeric objects such as ``None`` or lists) count as
    "not a number", so this predicate never raises for ordinary inputs.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True
|
|
|
class Implementor:
    """Generate C or Fortran source that evaluates a SymPy expression.

    Each generator returns a complete program: a function ``myfunc(x)`` whose
    body computes the expression, plus a ``main`` program that reads ``x``
    from stdin and prints ``myfunc(x)``. Two strategies exist per language:
    walking the DAG built by ``sympy_to_dag`` (``"dag_*"``), or SymPy's
    common-subexpression elimination followed by SymPy's code printers
    (``"cse_*"``).

    Args:
        expr: SymPy expression to implement. The generated signatures only
            provide ``x``, so presumably ``x`` is its only free symbol.
        constants: Optional mapping of extra name -> value; each is emitted as
            a local variable (C) or a named constant (Fortran).
        dtype: C floating type, expected to be ``"double"`` or ``"float"``.
    """

    # Literals declared inside float-precision C functions in place of the
    # (double-precision) <math.h> macros.
    _C_FLOAT_CONSTS = {"pi": "3.14159265F", "E": "2.71828183F"}
    # <math.h> macros used for E / pi in every non-float C build.
    _C_MACROS = {"pi": "M_PI", "E": "M_E"}

    def __init__(self, expr, constants=None, dtype="double"):
        self.expr = expr
        # None instead of a shared mutable {} default; a dict passed by the
        # caller is still stored by reference, as before.
        self.constants = {} if constants is None else constants
        self.cdtype = dtype
        # printf/scanf conversion used by the generated C main().
        self.cpf = "lf" if dtype == "double" else "f"
        # Fortran type name matching the C dtype.
        self.fdtype = "double precision" if dtype == "double" else "real"

    def implement(self, impl):
        """Return the program text for strategy ``impl``.

        Args:
            impl: One of ``"dag_c"``, ``"cse_c"``, ``"dag_fortran"``,
                ``"cse_fortran"``.

        Raises:
            ValueError: if ``impl`` names no known strategy (previously this
                silently returned None).
        """
        dispatch = {
            "dag_c": self.dag_to_c_impl,
            "cse_c": self.sympy_cse_c_impl,
            "dag_fortran": self.dag_to_fortran_impl,
            "cse_fortran": self.sympy_cse_fortran_impl,
        }
        generator = dispatch.get(impl)
        if generator is None:
            raise ValueError(
                f"unknown implementation {impl!r}; expected one of {sorted(dispatch)}"
            )
        return generator()

    def op_c_impl(self, f, children):
        """Return the C expression applying operator ``f`` to ``children``.

        Args:
            f: DAG operator label: ``"add"``, ``"mul"``, ``"pow"``, ``"ln"``,
                or a unary (arity-1) entry of ``OPERATORS``.
            children: C expressions (variable names or literals) of operands.

        Raises:
            DecodeError: if ``f`` is not a supported operator.
        """
        # Anything other than double uses the f-suffixed <math.h> variants.
        single = self.cdtype != "double"
        if f == "add":
            return " + ".join(children)
        if f == "mul":
            return " * ".join(children)
        if f == "pow":
            assert len(children) == 2
            fn = "powf" if single else "pow"
            return f"{fn}({children[0]}, {children[1]})"
        if f == "ln":
            assert len(children) == 1
            fn = "logf" if single else "log"
            return f"{fn}({children[0]})"
        if f in OPERATORS and OPERATORS[f][1] == 1:
            # NOTE(review): assumes each unary operator name is a <math.h>
            # function with an 'f'-suffixed float variant — confirm against
            # OPERATORS (e.g. an "abs" entry would need fabs/fabsf).
            assert len(children) == 1
            suffix = "f" if single else ""
            return f"{f}{suffix}({children[0]})"
        raise DecodeError(f"C impl: operation {f} not handled")

    def op_f_impl(self, f, children):
        """Return the Fortran expression applying operator ``f`` to ``children``.

        Every operand of add/mul/pow is parenthesized so precedence never
        depends on the operand text.

        Raises:
            DecodeError: if ``f`` is not a supported operator.
        """
        if f == "add":
            return "(" + ")+(".join(children) + ")"
        if f == "mul":
            return "(" + ")*(".join(children) + ")"
        if f == "pow":
            assert len(children) == 2
            return f"({children[0]})**({children[1]})"
        if f == "ln":
            assert len(children) == 1
            return f"log({children[0]})"
        if f in OPERATORS and OPERATORS[f][1] == 1:
            assert len(children) == 1
            return f"{f}({children[0]})"
        raise DecodeError(f"F impl: operation {f} not handled")

    def full_c_code(self, body):
        """Wrap C statements ``body`` into ``myfunc`` plus a scanf/printf main()."""
        pre = f"#include <stdio.h>\n#include <math.h>\n{self.cdtype} myfunc({self.cdtype} x) {{"
        post = f"}}\nint main() {{ {self.cdtype} x; scanf(\"%{self.cpf}\", &x); printf(\"%{self.cpf}\", myfunc(x)); }}"
        return f"{pre}\n{body}\n{post}"

    def full_f_code(self, body):
        """Wrap Fortran lines ``body`` into function ``myfunc`` plus a main program.

        The header declares ``y`` (the result) and ``E``/``pi`` with the
        configured Fortran type; ``body`` must give ``E``/``pi`` their values
        if it uses them.
        """
        pre = "function myfunc(x) result(y)\nimplicit none\n" + \
            f"{self.fdtype}, intent(in) :: x\n{self.fdtype} :: y, E, pi\n"
        post = "end function myfunc\nprogram main\nimplicit none\n" + \
            f"{self.fdtype} :: x\n{self.fdtype} :: myfunc\n" + \
            "read(*, *) x\nprint *, \"y is:\", myfunc(x)\nend program main"
        return f"{pre}\n{body}\n{post}"

    def _c_leaf(self, label, lines, declared):
        """Return the C spelling of DAG leaf ``label``.

        In float mode, E/pi get a ``const float`` declaration appended to
        ``lines`` the first time they are seen (tracked in ``declared``) and
        are referenced by name afterwards. In every other mode, the <math.h>
        macros are used instead.
        """
        if label not in self._C_MACROS:
            return label
        if self.cdtype != "float":
            return self._C_MACROS[label]
        if label not in declared:
            lines.append(f"const float {label} = {self._C_FLOAT_CONSTS[label]};")
            declared.add(label)
        return label

    @staticmethod
    def _fortran_math_params(csuf):
        """Return PARAMETER statements giving ``E`` and ``pi`` their values.

        With a precision suffix (``"d0"``), full double-precision digits are
        emitted so the constants are not silently truncated to single precision.
        """
        if csuf:
            return [
                f"parameter (E = 2.718281828459045{csuf})",
                f"parameter (pi = 3.141592653589793{csuf})",
            ]
        return ["parameter (E = 2.71828183)", "parameter (pi = 3.14159265)"]

    @staticmethod
    def _fortran_body(decls, stmts):
        """Join declaration lines, one blank line, then the remaining lines."""
        return "".join(f"{d}\n" for d in decls) + "\n" + "".join(f"{s}\n" for s in stmts)

    def dag_to_c_impl(self):
        """Return a C program evaluating ``self.expr`` by walking its DAG.

        Every non-leaf node becomes one temporary ``t<i>``, so shared
        subexpressions are computed once.

        Raises:
            DecodeError: for an empty DAG or an unsupported operator.
        """
        dag = sympy_to_dag(self.expr, csuf="F" if self.cdtype == "float" else "")
        lines = [f"{self.cdtype} {c} = {v};" for c, v in self.constants.items()]
        declared = set()  # float-mode E/pi declarations already emitted
        varidx = it.count()
        # Reversed topological order visits children before their parents;
        # assumes sympy_to_dag yields a single root, which is visited last.
        order = list(reversed(list(nx.topological_sort(dag))))
        if not order:
            raise DecodeError("C impl: empty expression DAG")
        for node in order:
            attrs = dag.nodes[node]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if not children:
                attrs["var"] = self._c_leaf(attrs["label"], lines, declared)
                continue
            varname = f"t{next(varidx)}"
            cexpr = self.op_c_impl(attrs["label"], children)
            attrs["var"] = varname
            lines.append(f"{self.cdtype} {varname} = {cexpr};")
        # Return the root's var rather than the last temporary, so an
        # expression that is a single leaf (e.g. just ``x``) also works.
        lines.append(f"return {dag.nodes[order[-1]]['var']};")
        return self.full_c_code("".join(f"{line}\n" for line in lines))

    def dag_to_fortran_impl(self):
        """Return a Fortran program evaluating ``self.expr`` by walking its DAG.

        Raises:
            DecodeError: for an empty DAG or an unsupported operator.
        """
        csuf = "" if self.fdtype == "real" else "d0"
        # NOTE(review): with csuf == "" integer leaves may come out as Fortran
        # integers, where e.g. 5**(-1) is integer arithmetic — verify the
        # literals sympy_to_dag emits in single-precision mode.
        dag = sympy_to_dag(self.expr, csuf=csuf)
        decls = [f"{self.fdtype} :: {c}" for c in self.constants]
        stmts = self._fortran_math_params(csuf)
        stmts += [f"parameter ({c} = {v}{csuf})" for c, v in self.constants.items()]
        varidx = it.count()
        # Children before parents; assumes a single root, visited last.
        order = list(reversed(list(nx.topological_sort(dag))))
        if not order:
            raise DecodeError("F impl: empty expression DAG")
        for node in order:
            attrs = dag.nodes[node]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if not children:
                attrs["var"] = attrs["label"]
                continue
            varname = f"t{next(varidx)}"
            fexpr = self.op_f_impl(attrs["label"], children)
            attrs["var"] = varname
            decls.append(f"{self.fdtype} :: {varname}")
            stmts.append(f"{varname} = {fexpr}")
        # Root's var handles single-leaf expressions too; no C-style ';'.
        stmts.append(f"y = {dag.nodes[order[-1]]['var']}")
        return self.full_f_code(self._fortran_body(decls, stmts))

    def sympy_cse_c_impl(self):
        """Return a C program built from ``sympy.cse`` + ``sympy.ccode``."""
        if self.cdtype == "float":
            # Print float32 types/functions and avoid the double M_* macros.
            extraargs = {
                "type_aliases": {ast.real: ast.float32},
                "math_macros": {},
            }
        else:
            extraargs = {}
        lines = [f"{self.cdtype} {c} = {v};" for c, v in self.constants.items()]
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            # ccode(..., assign_to=...) already terminates its statement with
            # ';', so no extra one is appended.
            code = sp.ccode(vxpr, assign_to=vname.name, **extraargs)
            lines.append(f"{self.cdtype} {vname.name}; {code}")
        assert len(xpr) == 1
        code = sp.ccode(xpr[0], assign_to="y", **extraargs)
        lines.append(f"{self.cdtype} y; {code} return y;")
        return self.full_c_code("".join(f"{line}\n" for line in lines))

    def sympy_cse_fortran_impl(self):
        """Return a Fortran program built from ``sympy.cse`` + ``sympy.fcode``.

        ``fcode`` is called with ``human=False`` so the named constants it
        needs (e.g. ``pi``) come back separately. They are declared once in
        the specification part, instead of as PARAMETER lines interleaved
        with (and possibly repeated among) the executable assignments.
        """
        csuf = "" if self.fdtype == "real" else "d0"
        decls = [f"{self.fdtype} :: {c}" for c in self.constants]
        params = [f"parameter ({c} = {v}{csuf})" for c, v in self.constants.items()]
        stmts = []
        xvars, xpr = sp.cse(self.expr)
        assert len(xpr) == 1
        targets = []
        for vname, vxpr in xvars:
            decls.append(f"{self.fdtype} :: {vname.name}")
            targets.append((vxpr, vname.name))
        targets.append((xpr[0], "y"))
        numconsts = {}
        for expr, target in targets:
            consts, unsupported, code = sp.fcode(
                expr, assign_to=target, standard=95, source_format="free", human=False
            )
            numconsts.update((str(sym), value) for sym, value in consts)
            # Keep the human-mode warning about unsupported functions as comments.
            stmts.extend(
                f"! Not supported in Fortran: {type(fn).__name__}"
                for fn in sorted(unsupported, key=str)
            )
            stmts.append(code)
        for name, value in sorted(numconsts.items()):
            if name not in ("E", "pi"):  # E and pi are declared by full_f_code
                decls.append(f"{self.fdtype} :: {name}")
            params.append(f"parameter ({name} = {value})")
        fstr = self._fortran_body(decls, params + stmts)
        if self.fdtype == "real":
            # fcode emits double-precision literals ("...d0"); strip the
            # suffix so they become default-real literals.
            fstr = fstr.replace("d0", "")
        return self.full_f_code(fstr)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: parse a fixed prefix-notation expression and print the program
    # text produced by each of the four generation strategies.
    from .parser import parse_prefix_to_sympy

    demo_tokens = "add mul div INT+ 1 INT+ 5 x mul div INT+ 1 INT+ 5 mul x tan pow x INT+ 2".split(" ")
    demo = Implementor(parse_prefix_to_sympy(demo_tokens), dtype="float")

    demo_outputs = (
        ("DAG C:", demo.dag_to_c_impl),
        ("DAG Fortran:", demo.dag_to_fortran_impl),
        ("CSE C:", demo.sympy_cse_c_impl),
        ("CSE Fortran:", demo.sympy_cse_fortran_impl),
    )
    for heading, generate in demo_outputs:
        print(heading)
        print(generate())
|
|