# Source code for modelparameters.codegeneration

# Copyright (C) 2012 Johan Hake
#
# This file is part of ModelParameters.
#
# ModelParameters is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ModelParameters is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ModelParameters. If not, see <http://www.gnu.org/licenses/>.
# System imports
import re

from . import sympy as sp
from .logger import error
from .sympy.core.function import AppliedUndef as _AppliedUndef
from .sympy.printing import StrPrinter as _StrPrinter
from .sympy.printing.latex import latex as _sympy_latex
from .sympy.printing.latex import LatexPrinter as _LatexPrinter
from .sympy.printing.precedence import precedence as _precedence
from .utils import check_arg as _check_arg
from .utils import scalars as _scalars


# Value of the sympy printer ``order`` setting used by the printers in this
# module. With "none", ``_print_Mul`` keeps a product's factors in
# ``Mul.make_args`` order instead of re-sorting them with
# ``as_ordered_factors``.
_order = "none"


# A collection of language specific keywords
_cpp_keywords = [
    "auto",
    "const",
    "double",
    "float",
    "int",
    "short",
    "struct",
    "break",
    "continue",
    "else",
    "for",
    "long",
    "signed",
    "switch",
    "case",
    "default",
    "enum",
    "goto",
    "register",
    "sizeof",
    "typedef",
    "char",
    "do",
    "extern",
    "if",
    "return",
    "static",
    "union",
    "while",
    "asm",
    "dynamic_cast",
    "namespace",
    "reinterpret_cast",
    "try",
    "bool",
    "explicit",
    "new",
    "static_cast",
    "typeid",
    "volatile",
    "catch",
    "operator",
    "template",
    "typename",
    "class",
    "friend",
    "private",
    "this",
    "using",
    "const_cast",
    "inline",
    "public",
    "throw",
    "virtual",
    "delete",
    "mutable",
    "protected",
    "wchar_t",
    "or",
    "and",
    "xor",
    "not",
    "unsigned",
    "void",
]

_python_keywords = [
    "and",
    "del",
    "from",
    "not",
    "while",
    "as",
    "elif",
    "global",
    "or",
    "with",
    "assert",
    "else",
    "if",
    "pass",
    "yield",
    "break",
    "except",
    "import",
    "print",
    "class",
    "exec",
    "in",
    "raise",
    "continue",
    "finally",
    "is",
    "return",
    "def",
    "for",
    "lambda",
    "try",
]

_matlab_keywords = [
    "break",
    "case",
    "catch",
    "classdef",
    "continue",
    "else",
    "elseif",
    "end",
    "for",
    "function",
    "global",
    "if",
    "otherwise",
    "parfor",
    "persistent",
    "return",
    "spmd",
    "switch",
    "try",
    "while",
]

_julia_keywords = [
    "break",
    "continue",
    "if",
    "elseif",
    "else",
    "global",
    "@assert" "for",
    "while",
    "end",
    "function",
    "in",
    "using",
    "nothing",
]

_fortran_keywords = (
    [
        "assign",
        "backspace",
        "block data",
        "call",
        "close",
        "common",
        "continue",
        "data",
        "dimension",
        "do",
        "else",
        "else if",
        "end",
        "endfile",
        "endif",
        "entry",
        "equivalence",
        "external",
        "format",
        "function",
        "goto",
        "if",
        "implicit",
        "inquire",
        "intrinsic",
        "open",
        "parameter",
        "pause",
        "print",
        "program",
        "read",
        "return",
        "rewind",
        "rewrite",
        "save",
        "stop",
        "subroutine",
        "then",
        "write",
    ]
    + [
        "allocate",
        "allocatable",
        "case",
        "contains",
        "cycle",
        "deallocate",
        "elsewhere",
        "exit",
        "include",
        "interface",
        "intent",
        "module",
        "namelist",
        "nullify",
        "only",
        "operator",
        "optional",
        "pointer",
        "private",
        "procedure",
        "public",
        "result",
        "recursive",
        "select",
        "sequence",
        "target",
        "use",
        "while",
        "where",
    ]
    + ["forall", "pure"]
    + [
        "abstract",
        "associate",
        "asynchronous",
        "bind",
        "class",
        "deferred",
        "enum",
        "enumerator",
        "extends",
        "final",
        "flush",
        "generic",
        "import",
        "non_overridable",
        "nopass",
        "pass",
        "protected",
        "value",
        "volatile",
        "wait",
    ]
    + [
        "block",
        "codimension",
        "do concurrent",
        "contiguous",
        "critical",
        "error stop",
        "submodule",
        "sync all",
        "sync images",
        "sync memory",
        "lock",
        "unlock",
    ]
)

# Union of the keyword lists of every target language, as a set.
_all_keywords = {
    *_cpp_keywords,
    *_python_keywords,
    *_matlab_keywords,
    *_fortran_keywords,
    *_julia_keywords,
}


def _get_potence(value):
    """
    Split a positive number into an engineering-notation mantissa and exponent.

    The exponent is an integer multiple of 3 and the mantissa is rounded to
    three significant digits, so that ``value ~= rest * 10**exponent`` with
    ``1 <= rest < 1000``.

    Arguments
    ---------
    value : float
        A strictly positive number.

    Returns
    -------
    tuple
        ``(rest, exponent)`` with ``exponent`` an ``int``.

    Raises
    ------
    ValueError
        If ``value`` is not positive (raised by ``math.log10``).
    """
    import math

    # Floor division keeps the exponent an integer multiple of 3. True
    # division (``/``) would give e.g. 4.0 for 12345 instead of 3.
    exponent = int(math.log10(value)) // 3 * 3
    rest = _round2(float(value) / 10 ** exponent, 3)
    if rest < 1:
        exponent -= 3
        rest *= 1e3
    elif rest >= 1e3:
        # Rounding to 3 significant digits can carry over, e.g. 999.9 -> 1000
        exponent += 3
        rest /= 1e3
    return rest, exponent


def _round2(x, n=0, sigs4n=1):
    """
    Return x rounded to the specified number of significant digits, n, as
    counted from the first non-zero digit.

    If n=0 (the default value for round2) then the magnitude of the
    number will be returned (e.g. round2(12) returns 10.0).

    If n<0 then x will be rounded to the nearest multiple of n which, by
    default, will be rounded to 1 digit (e.g. round2(1.23,-.28) will round
    1.23 to the nearest multiple of 0.3.

    Regardless of n, if x=0, 0 will be returned.
    """
    import math

    if x == 0:
        return x
    if n < 0:
        n = _round2(-n, sigs4n)
        return n * int(x / n + 0.5)
    if n == 0:
        return 10.0 ** (int(math.floor(math.log10(abs(x)))))
    return round(x, int(n) - 1 - int(math.floor(math.log10(abs(x)))))


def _coeff_isneg(a):
    """Return True if the leading Number is negative.

    Examples
    ========

    >>> from sympy.core.function import _coeff_isneg
    >>> from sympy import S, Symbol, oo, pi
    >>> _coeff_isneg(-3*pi)
    True
    >>> _coeff_isneg(S(3))
    False
    >>> _coeff_isneg(-oo)
    True
    >>> _coeff_isneg(Symbol('n', negative=True)) # coeff is 1
    False

    """

    if a.is_Mul:
        a = a.args[0]
    return a.is_Number and a.is_negative


_relational_map = {
    "==": "Eq",
    "!=": "Ne",
    "<": "Lt",
    "<=": "Le",
    ">": "Gt",
    ">=": "Ge",
}


_relational_map_matlab = {
    "==": "==",
    "!=": "~=",
    "<": "<",
    "<=": "<=",
    ">": ">",
    ">=": ">=",
}


def _print_Mul(self, expr):
    """
    Print a product as ``num1*num2/(den1*den2)``, with any leading negative
    coefficient pulled out as a ``-`` prefix.

    Written as a printer method: ``self`` is the sympy printer instance and
    must provide ``order`` and ``parenthesize``.

    The factors are split as follows:

    - A commutative power with a negative rational exponent goes into the
      denominator, with its exponent negated.
    - A rational coefficient contributes its numerator (if not 1) to the
      numerator and its denominator (if not 1) to the denominator.
    - Every other factor goes into the numerator.

    Arguments
    ---------
    expr : sympy.Mul
        The product to print.

    Returns
    -------
    str
        The printed product.
    """
    prec = _precedence(expr)

    if self.order not in ("old", "none"):
        args = expr.as_ordered_factors()
    else:
        # use make_args in case expr was something like -x -> x
        args = sp.Mul.make_args(expr)

    args = tuple(args)

    if _coeff_isneg(expr):
        # The sign is printed separately. A leading -1 is dropped altogether;
        # any other negative coefficient is replaced by its absolute value.
        # (Previously this compared against +1, which a negative coefficient
        # can never equal, so the -1 case fell through to the negation branch
        # and was only removed later as a unit Rational.)
        if args[0].is_integer and int(args[0]) == -1:
            args = args[1:]
        else:
            args = (-args[0],) + args[1:]
        sign = "-"
    else:
        sign = ""

        # If first argument is Mul we do not want to add parentheses
        if isinstance(args[0], sp.Mul):
            prec -= 1

    a = []  # items in the numerator
    b = []  # items that are in the denominator (if any)

    # Gather args for numerator/denominator
    for item in args:
        if (
            item.is_commutative
            and item.is_Pow
            and item.exp.is_Rational
            and item.exp.is_negative
        ):
            if item.exp != -1:
                b.append(sp.Pow(item.base, -item.exp, evaluate=False))
            else:
                # x**-1 -> x (Pow with exponent 1 evaluates to its base)
                b.append(sp.Pow(item.base, -item.exp))
        elif item.is_Rational and item is not sp.S.Infinity:
            if item.p != 1:
                a.append(sp.Rational(item.p))
            if item.q != 1:
                b.append(sp.Rational(item.q))
        else:
            a.append(item)

    a = a or [sp.S.One]

    a_str = [self.parenthesize(x, prec) for x in a]
    b_str = [self.parenthesize(x, prec) for x in b]

    if len(b) == 0:
        return sign + "*".join(a_str)
    elif len(b) == 1:
        if len(a) == 1 and not (a[0].is_Atom or a[0].is_Add):
            return sign + f"{a_str[0]}/" + "*".join(b_str)
        else:
            return sign + "*".join(a_str) + f"/{b_str[0]}"
    else:
        return sign + "*".join(a_str) + f"/({'*'.join(b_str)})"


# Patch sympy print Function
# Keep a handle on sympy's original StrPrinter._print_Function so that the
# replacement ``_print_Function`` defined below can delegate to it for
# expressions that are not undefined (AppliedUndef) functions.
_old_print_Function = _StrPrinter._print_Function


def _print_Function(self, expr):
    """
    Print an applied function for the string printer.

    Instances of sympy's ``AppliedUndef`` (user-created, undefined functions)
    are printed as their bare name, without the argument list. Every other
    function is delegated to the original ``StrPrinter._print_Function``.
    """
    if not isinstance(expr, _AppliedUndef):
        return _old_print_Function(self, expr)
    return expr.func.__name__


# _StrPrinter._print_Function = _print_Function

# Matches one atomic unit token in a unit string: a run of ASCII letters,
# optionally followed by ``**`` and an exponent built from digits and minus
# signs (e.g. "m", "s**-2"). Characters outside the matches, such as "*"
# and "/", are skipped. Used by ``latex_unit`` through ``re.findall``.
_unit_template = re.compile(r"([a-zA-Z]+\*\*[\-0-9]+|[a-zA-Z]+)")


def latex_unit(unit):
    r"""
    Return a LaTeX-formatted string describing the given unit.

    The unit string is split into atomic tokens: letters, optionally followed
    by ``**`` and an integer exponent. Each token is formatted as follows:

    - It is wrapped in ``\mathrm{...}``.
    - An exponent becomes a superscript.
    - A leading ``u`` is rendered as the micro prefix ``\mu``.

    The tokens are joined with a thin space ``\,``.

    Only the tokens are inspected. The operators between them (``*``, ``/``)
    are discarded, so a quotient has to be written with a negative exponent
    (``"m*s**-2"`` rather than ``"m/s**2"``) to render correctly.

    Arguments
    ---------
    unit : str
        The unit, e.g. ``"mV"`` or ``"uA*cm**-2"``. The dimensionless unit
        ``"1"`` gives an empty string.

    Returns
    -------
    str
        The LaTeX representation of ``unit``.

    Examples
    --------
    >>> latex_unit("m*s**-2")
    '\\mathrm{m}\\,\\mathrm{s}^{-2}'
    >>> latex_unit("uA*cm**-2")
    '\\mu\\mathrm{A}\\,\\mathrm{cm}^{-2}'
    """
    _check_arg(unit, str)
    if unit == "1":
        return ""

    atomic_units = []
    # A separate loop name so the ``unit`` argument is not shadowed.
    for atom in re.findall(_unit_template, unit):
        micro = False
        exp = 0

        # Check for usage of micro
        if atom[0] == "u":
            atom = atom[1:]
            micro = True

        # Check for exponent
        if "**" in atom:
            atom, exp = atom.split("**")

        # Wrap text in mathrm
        latex = f"\\mathrm{{{atom}}}"
        if exp:
            latex += f"^{{{exp}}}"
        if micro:
            latex = "\\mu" + latex
        atomic_units.append(latex)

    return "\\,".join(atomic_units)
class _CustomPythonPrinter(_StrPrinter):
    """
    Sympy string printer emitting Python-syntax expressions.

    Used as-is (empty namespace) by sympycode() and as the base class of
    _CustomPythonCodePrinter. A non-empty ``namespace`` ("math", "np",
    "numpy" or "ufl") is prepended, followed by a dot, to the math functions
    printed here (sqrt, pow, ...).
    """

    def __init__(self, namespace=""):
        assert namespace in ["", "math", "np", "numpy", "ufl"]
        self._namespace = namespace if not namespace else namespace + "."
        _StrPrinter.__init__(self, settings=dict(order=_order))

    def _print_Mod(self, expr):
        # Print the operands with this printer (so the custom rules of this
        # class apply) and parenthesize them. Formatting them with str()
        # turned Mod(x + 1, 2) into "(x + 1 % 2)", i.e. x + (1 % 2).
        PREC = _precedence(expr)
        return "({0} % {1})".format(
            self.parenthesize(expr.args[0], PREC),
            self.parenthesize(expr.args[1], PREC),
        )

    # NOTE: sympy dispatches on the expression's class name (``_print_log``
    # for sympy's ``log``), so this hook is never reached; log() is printed
    # through _print_Function instead.
    def _print_Log(self, expr):
        if self._namespace == "ufl.":
            return f"{self._namespace}ln({self._print(expr.base)})"
        else:
            return f"{self._namespace}log({self._print(expr.base)})"

    def _print_Abs(self, expr):
        if self._namespace == "math.":
            return f"{self._namespace}fabs({self.stringify(expr.args, ', ')})"
        elif self._namespace == "ufl.":
            return f"abs({self.stringify(expr.args, ', ')})"
        else:
            return f"{self._namespace}abs({self.stringify(expr.args, ', ')})"

    def _print_Float(self, expr):
        # Zero and non-finite values are special-cased; everything else is
        # printed through Python's float repr.
        if expr.is_zero:
            return "0"
        if not expr.is_finite:
            return _StrPrinter._print_Float(self, expr)
        return str(float(expr))

    def _print_Derivative(self, expr):
        if not isinstance(expr.args[1], (_AppliedUndef, sp.Symbol)):
            error(
                "Can only print Derivative code with a single dependent "
                "variable. Got: {0}".format(sympycode(expr.args[1])),
            )
        # d f(t) / dt is printed as the plain name "df_dt"
        if isinstance(expr.args[0], _AppliedUndef):
            return "d%s_d%s" % (
                expr.args[0].func.__name__,
                "_".join(self._print(arg) for arg in expr.args[1:]),
            )
        return _StrPrinter._print_Derivative(self, expr)

    def _print_Subs(self, expr):
        # Execute the substitution and print the result
        return self._print(expr.expr.xreplace(dict(zip(expr.variables, expr.point))))

    def _print_Sqrt(self, expr):
        return f"{self._namespace}sqrt({self._print(expr.args[0])})"

    def _print_Relational(self, expr):
        return "{0}({1}, {2})".format(
            _relational_map[expr.rel_op],
            self._print(expr.lhs),
            self._print(expr.rhs),
        )

    def _print_Piecewise(self, expr):
        # Nested Conditional(cond, expr, <rest>) calls, one per condition.
        result = ""
        num_par = 0
        for e, c in expr.args[:-1]:
            num_par += 1
            result += f"Conditional({self._print(c)}, {self._print(e)}, "
        last_line = self._print(expr.args[-1].expr) + ")" * num_par
        return result + last_line

    def _print_And(self, expr):
        if self._namespace == "ufl.":
            if len(expr.args) != 2:
                error("UFL does not support more than 2 operands to And")
            return f"ufl.And({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        return f"And({', '.join(self._print(arg) for arg in expr.args[::-1])})"

    def _print_Or(self, expr):
        if self._namespace == "ufl.":
            if len(expr.args) != 2:
                error("UFL does not support more than 2 operands to Or")
            return f"ufl.Or({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        return f"Or({', '.join(self._print(arg) for arg in expr.args[::-1])})"

    def _print_Pow(self, expr, rational=False):
        PREC = _precedence(expr)
        if expr.exp.is_integer and int(expr.exp) == 1:
            return self.parenthesize(expr.base, PREC)
        if expr.exp is sp.S.NegativeOne:
            return f"1.0/{self.parenthesize(expr.base, PREC)}"
        # Small integer powers are expanded into repeated multiplication
        if expr.exp.is_integer and int(expr.exp) in [2, 3]:
            return "({0})".format(
                "*".join(
                    self.parenthesize(expr.base, PREC) for i in range(int(expr.exp))
                ),
            )
        if expr.exp.is_integer and int(expr.exp) in [-2, -3]:
            return "1.0/({0})".format(
                "*".join(
                    self.parenthesize(expr.base, PREC) for i in range(-int(expr.exp))
                ),
            )
        if expr.exp is sp.S.Half and not rational:
            return f"{self._namespace}sqrt({self._print(expr.base)})"
        if expr.exp == -0.5:
            return f"1/{self._namespace}sqrt({self._print(expr.base)})"
        if self._namespace == "ufl.":
            return "{0}elem_pow({1}, {2})".format(
                self._namespace,
                self._print(expr.base),
                self._print(expr.exp),
            )
        if self._namespace in ["np.", "numpy."]:
            return "{0}power({1}, {2})".format(
                self._namespace,
                self._print(expr.base),
                self._print(expr.exp),
            )
        return "{0}pow({1}, {2})".format(
            self._namespace,
            self._print(expr.base),
            self._print(expr.exp),
        )

    _print_Function = _print_Function
    _print_Mul = _print_Mul


class _CustomPythonCodePrinter(_CustomPythonPrinter):
    """
    Python code printer used by pythoncode(); emits plain Python
    conditionals/booleans, or numpy/ufl calls depending on the namespace.
    """

    def _print_sign(self, expr):
        if self._namespace == "ufl.":
            return f"{self._namespace}sign({self._print(expr.args[0])})"
        elif self._namespace in ["math.", "numpy.", "np."]:
            return f"{self._namespace}copysign(1.0, {self._print(expr.args[0])})"
        return f"sign({self._print(expr.args[0])})"

    def _print_Mod(self, expr):
        if self._namespace == "math.":
            return f"{self._namespace}fmod({self.stringify(expr.args, ', ')})"
        else:
            return f"{self._namespace}mod({self.stringify(expr.args, ', ')})"

    def _print_Min(self, expr):
        if self._namespace == "ufl.":
            return f"ufl.{expr.func.__name__}({self.stringify(expr.args, ', ')})"
        return f"{expr.func.__name__.lower()}({self.stringify(expr.args, ', ')})"

    def _print_Max(self, expr):
        if self._namespace == "ufl.":
            return f"ufl.{expr.func.__name__}({self.stringify(expr.args, ', ')})"
        return f"{expr.func.__name__.lower()}({self.stringify(expr.args, ', ')})"

    def _print_Function(self, expr):
        func_name = expr.func.__name__
        if isinstance(expr, _AppliedUndef):
            return func_name
        elif func_name == "ceiling":
            return f"{self._namespace}ceil({self.stringify(expr.args, ', ')})"
        elif func_name == "log":
            if self._namespace == "ufl.":
                return f"{self._namespace}ln({self._print(expr.args[0])})"
            else:
                return f"{self._namespace}log({self._print(expr.args[0])})"
        else:
            return "{0}{1}".format(
                self._namespace,
                func_name.lower() + "({0})".format(self.stringify(expr.args, ", ")),
            )

    def _print_re(self, expr):
        assert len(expr.args) == 1
        return f"({self._print(expr.args[0])}).real"

    def _print_im(self, expr):
        assert len(expr.args) == 1
        return f"({self._print(expr.args[0])}).imag"

    def _print_Piecewise(self, expr):
        # One nested call/conditional per condition; all of them are closed
        # after the final (default) expression.
        result = ""
        num_par = 0
        if self._namespace == "ufl.":
            for e, c in expr.args[:-1]:
                num_par += 1
                result += f"ufl.conditional({self._print(c)}, {self._print(e)}, "
        elif self._namespace in ("numpy.", "np."):
            for e, c in expr.args[:-1]:
                num_par += 1
                result += self._namespace
                result += f"where({self._print(c)}, {self._print(e)}, "
        else:
            for e, c in expr.args[:-1]:
                num_par += 1
                result += "({0} if {1} else ".format(self._print(e), self._print(c))
        last_line = self._print(expr.args[-1].expr) + ")" * num_par
        return result + last_line

    def _print_Relational(self, expr):
        if self._namespace == "ufl.":
            return "ufl.{0}({1}, {2})".format(
                _relational_map[expr.rel_op].lower(),
                self._print(expr.lhs),
                self._print(expr.rhs),
            )
        return "{0} {1} {2}".format(
            self.parenthesize(expr.lhs, _precedence(expr)),
            expr.rel_op,
            self.parenthesize(expr.rhs, _precedence(expr)),
        )

    def _print_Pi(self, expr=None):
        return f"{self._namespace}pi"

    def _print_And(self, expr):
        PREC = _precedence(expr)
        if self._namespace == "ufl.":
            if len(expr.args) != 2:
                error("UFL does not support more than 2 operands to And")
            return f"ufl.And({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif self._namespace in ("numpy.", "np."):
            # logical_and is binary: fold extra operands into a nested And
            result = self._namespace
            if len(expr.args) == 2:
                result += "logical_and({0}, {1})".format(
                    self._print(expr.args[0]),
                    self._print(expr.args[1]),
                )
            else:
                result += "logical_and({0}, {1})".format(
                    self._print(expr.args[0]),
                    self._print(sp.And(*expr.args[1:])),
                )
            return result
        return " and ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_Or(self, expr):
        PREC = _precedence(expr)
        if self._namespace == "ufl.":
            if len(expr.args) != 2:
                error("UFL does not support more than 2 operands to Or")
            return f"ufl.Or({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif self._namespace in ("numpy.", "np."):
            # logical_or is binary: fold extra operands into a nested Or
            result = self._namespace
            if len(expr.args) == 2:
                result += "logical_or({0}, {1})".format(
                    self._print(expr.args[0]),
                    self._print(expr.args[1]),
                )
            else:
                result += "logical_or({0}, {1})".format(
                    self._print(expr.args[0]),
                    self._print(sp.Or(*expr.args[1:])),
                )
            return result
        return " or ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])


class _CustomCCodePrinter(_StrPrinter):
    """
    C/C++ code printer used by ccode() and cppcode().

    ``cpp=True`` prefixes math functions with ``std::``. Single precision
    (any ``float_precision`` other than "double") suffixes float literals
    and float-capable math functions with ``f``.
    """

    # Functions routed through _print_math_function, so the std:: prefix and
    # the single-precision "f" suffix are applied to them.
    _math_functions_with_float_variant = frozenset(
        [
            "fabs", "fmod", "remainder", "remquo", "fma", "fmax", "fmin",
            "fdim", "exp", "exp2", "expm1", "log", "log2", "log10", "log1p",
            "ilogb", "logb", "hypot", "cbrt", "sqrt", "pow", "sin", "cos",
            "tan", "asin", "acos", "atan", "atan2", "sinh", "cosh", "tanh",
            "asinh", "acosh", "atanh", "erf", "erfc", "lgamma", "tgamma",
            "ceil", "floor", "trunc", "round", "nearbyint", "rint",
        ],
    )

    def __init__(self, cpp=False, float_precision="double", **settings):
        super(_CustomCCodePrinter, self).__init__(settings=settings)
        self._prefix = "std::" if cpp else ""
        # for single precision we need to suffix float literals with "f"
        # and we also need to suffix math library function calls with "f"
        self._math_suffix = "" if float_precision == "double" else "f"
        self._float_postfix = "" if float_precision == "double" else "f"

    def _print_math_function(self, function_name, arguments_str):
        # add the appropriate namespace and suffix for math functions
        return f"{self._prefix}{function_name}{self._math_suffix}({arguments_str})"

    def _print_Relational(self, expr):
        # sympy's rel_op ("==", "!=", "<", "<=", ">", ">=") is already valid
        # C/C++ syntax, so no MATLAB-oriented translation is applied here.
        return "{0} {1} {2}".format(
            self.parenthesize(expr.lhs, _precedence(expr)),
            expr.rel_op,
            self.parenthesize(expr.rhs, _precedence(expr)),
        )

    def _print_Float(self, expr):
        f_str = _StrPrinter._print_Float(self, expr)
        return f_str + self._float_postfix

    def _print_One(self, expr):
        return "1." + self._float_postfix

    def _print_Zero(self, expr):
        return "0." + self._float_postfix

    def _print_Integer(self, expr):
        # Print integers as floating point literals to avoid integer division
        return f"{expr.p}.{self._float_postfix}"

    def _print_NegativeOne(self, expr):
        return "-1." + self._float_postfix

    def _print_Rational(self, expr):
        return "{0}.{2}/{1}.{2}".format(expr.p, expr.q, self._float_postfix)

    def _print_Min(self, expr):
        # NOTE(review): older g++ (< 4.7 with -ansi) lacked fmin/fmax in std::;
        # they are printed like every other math function here.
        return self._print_math_function("fmin", self.stringify(expr.args, ", "))

    def _print_Max(self, expr):
        return self._print_math_function("fmax", self.stringify(expr.args, ", "))

    def _print_Ceiling(self, expr):
        return self._print_math_function("ceil", self.stringify(expr.args, ", "))

    def _print_Abs(self, expr):
        return self._print_math_function("fabs", self.stringify(expr.args, ", "))

    def _print_Mod(self, expr):
        return self._print_math_function("fmod", self.stringify(expr.args, ", "))

    def _print_Not(self, expr):
        # sympy's default "~x" is bitwise NOT in C: ~0 and ~1 are both
        # non-zero, i.e. always true. Use logical negation instead.
        PREC = _precedence(expr)
        return "!" + self.parenthesize(expr.args[0], PREC)

    def _print_Piecewise(self, expr):
        # Nested ternaries "(c1 ? e1 : (c2 ? e2 : default))". Every condition
        # opens one parenthesis, so close one per condition at the end.
        result = ""
        num_par = 0
        for e, c in expr.args[:-1]:
            num_par += 1
            result += f"({self._print(c)} ? {self._print(e)} : "
        last_line = self._print(expr.args[-1].expr) + ")" * num_par
        return result + last_line

    def _print_Add(self, expr):
        args = expr.args
        # Special treatment of (exp(a) - 1), where we should use expm1
        if len(args) == 2:
            a, b = args
            if type(b) == sp.exp:
                a, b = b, a
            if type(a) is sp.exp and type(b) is sp.numbers.NegativeOne:
                return self._print_math_function("expm1", self.stringify(a.args, ", "))
        return super()._print_Add(expr)

    def _print_Function(self, expr):
        if isinstance(expr, _AppliedUndef):
            return expr.func.__name__
        func_name = expr.func.__name__.lower()
        if func_name in self._math_functions_with_float_variant:
            return self._print_math_function(
                func_name,
                self.stringify(expr.args, ", "),
            )
        return f"{self._prefix}{func_name}({self.stringify(expr.args, ', ')})"

    def _print_Subs(self, expr):
        # Execute the substitution and print the result
        return self._print(expr.expr.xreplace(dict(zip(expr.variables, expr.point))))

    def _print_Derivative(self, expr):
        if not isinstance(expr.args[1], (_AppliedUndef, sp.Symbol)):
            error(
                "Can only print Derivative code with a single dependent "
                "variable. Got: {0}".format(sympycode(expr.args[1])),
            )
        # d f(t) / dt is printed as the plain name "df_dt"
        if isinstance(expr.args[0], _AppliedUndef):
            return "d%s_d%s" % (
                expr.args[0].func.__name__,
                "_".join(self._print(arg) for arg in expr.args[1:]),
            )
        return _StrPrinter._print_Derivative(self, expr)

    def _print_Pow(self, expr, rational=False):
        PREC = _precedence(expr)
        if expr.exp.is_integer and int(expr.exp) == 1:
            return self.parenthesize(expr.base, PREC)
        if expr.exp is sp.S.NegativeOne:
            return f"1.0{self._float_postfix}/{self.parenthesize(expr.base, PREC)}"
        if expr.exp.is_integer and int(expr.exp) in [2, 3]:
            return "({0})".format(
                "*".join(
                    self.parenthesize(expr.base, PREC) for i in range(int(expr.exp))
                ),
            )
        if expr.exp.is_integer and int(expr.exp) == 4:
            # Use parentheses strategically to facilitate expression reuse
            # pow(a, 4) = pow(a, 2) * pow(a, 2)
            # pow(a, 2) = a * a
            return "(({0})*({0}))".format(
                "({0})*({0})".format(self.parenthesize(expr.base, PREC)),
            )
        if expr.exp.is_integer and int(expr.exp) in [-2, -3]:
            return "1.0{1}/({0})".format(
                "*".join(
                    self.parenthesize(expr.base, PREC) for i in range(-int(expr.exp))
                ),
                self._float_postfix,
            )
        if expr.exp is sp.S.Half and not rational:
            return self._print_math_function("sqrt", self._print(expr.base))
        if expr.exp == -0.5:
            return f"1/{self._print_math_function('sqrt', self._print(expr.base))}"
        return self._print_math_function(
            "pow",
            f"{self._print(expr.base)}, {self._print(expr.exp)}",
        )

    def _print_sign(self, expr):
        # Route through _print_math_function (like every other math call) so
        # single precision gets copysignf(1.0f, x).
        return self._print_math_function(
            "copysign",
            f"1.0{self._float_postfix}, {self._print(expr.args[0])}",
        )

    def _print_Pi(self, expr=None):
        return "M_PI"

    def _print_And(self, expr):
        PREC = _precedence(expr)
        return " && ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_Or(self, expr):
        PREC = _precedence(expr)
        return " || ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_re(self, expr):
        assert len(expr.args) == 1
        return f"{self._prefix}creal({self._print(expr.args[0])})"

    def _print_im(self, expr):
        assert len(expr.args) == 1
        return f"{self._prefix}cimag({self._print(expr.args[0])})"

    def _print_Symbol(self, expr):
        # "I" is renamed to avoid clashing with an existing "I" in C
        # (presumably the <complex.h> imaginary unit macro)
        if expr.name == "I":
            return "I_"
        return expr.name

    _print_Mul = _print_Mul


class _CustomMatlabCodePrinter(_StrPrinter):
    """
    MATLAB/Octave code printer used by matlabcode().
    """

    def __init__(self, **settings):
        super(_CustomMatlabCodePrinter, self).__init__(settings=settings)

    def _print_Float(self, expr):
        # Zero and non-finite values are special-cased; everything else is
        # printed through Python's float repr.
        if expr.is_zero:
            return "0"
        if not expr.is_finite:
            return _StrPrinter._print_Float(self, expr)
        return str(float(expr))

    def _print_Min(self, expr):
        return f"min({self.stringify(expr.args, ', ')})"

    def _print_Max(self, expr):
        return f"max({self.stringify(expr.args, ', ')})"

    def _print_Ceiling(self, expr):
        return f"ceil({self.stringify(expr.args, ', ')})"

    def _print_Piecewise(self, expr):
        # Arithmetic blend "((c)*(e) + ~(c)*(<rest>))". Each condition opens
        # one parenthesis that is closed after the final (default) expression.
        result = ""
        num_par = 0
        for e, c in expr.args[:-1]:
            num_par += 1
            result += f"(({self._print(c)})*({self._print(e)}) + ~({self._print(c)})*"
        last_line = f"({self._print(expr.args[-1].expr)})" + ")" * num_par
        return result + last_line

    def _print_Function(self, expr):
        if isinstance(expr, _AppliedUndef):
            return expr.func.__name__
        return f"{expr.func.__name__.lower()}({self.stringify(expr.args, ', ')})"

    def _print_Pow(self, expr):
        PREC = _precedence(expr)
        if expr.exp.is_integer and int(expr.exp) == 1:
            return self.parenthesize(expr.base, PREC)
        if expr.exp is sp.S.NegativeOne:
            return f"1.0/{self.parenthesize(expr.base, PREC)}"
        if expr.exp == 0.5:
            return f"sqrt({self._print(expr.base)})"
        # FIXME: Fix parentheses
        return "{0}^{1}".format(
            self.parenthesize(expr.base, PREC),
            self.parenthesize(expr.exp, PREC),
        )

    def _print_And(self, expr):
        PREC = _precedence(expr)
        return " & ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_Not(self, expr):
        PREC = _precedence(expr)
        return "~" + self.parenthesize(expr.args[0], PREC)

    def _print_Relational(self, expr):
        return "{0} {1} {2}".format(
            self.parenthesize(expr.lhs, _precedence(expr)),
            _relational_map_matlab[expr.rel_op],
            self.parenthesize(expr.rhs, _precedence(expr)),
        )

    def _print_Or(self, expr):
        PREC = _precedence(expr)
        return " | ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_re(self, expr):
        assert len(expr.args) == 1
        return f"real({self._print(expr.args[0])})"

    def _print_im(self, expr):
        assert len(expr.args) == 1
        return f"imag({self._print(expr.args[0])})"

    _print_Mul = _print_Mul


class _CustomJuliaCodePrinter(_StrPrinter):
    """
    Julia code printer used by juliacode().
    """

    def __init__(self, **settings):
        super(_CustomJuliaCodePrinter, self).__init__(settings=settings)

    def _print_Float(self, expr):
        # Zero and non-finite values are special-cased; everything else is
        # printed through Python's float repr.
        if expr.is_zero:
            return "0"
        if not expr.is_finite:
            return _StrPrinter._print_Float(self, expr)
        return str(float(expr))

    def _print_Min(self, expr):
        return f"min({self.stringify(expr.args, ', ')})"

    def _print_Max(self, expr):
        return f"max({self.stringify(expr.args, ', ')})"

    def _print_Ceiling(self, expr):
        return f"ceil({self.stringify(expr.args, ', ')})"

    def _print_Piecewise(self, expr):
        # Nested ternaries "(c1 ? e1 : (c2 ? e2 : default))". Every condition
        # opens one parenthesis, so close one per condition at the end.
        result = ""
        num_par = 0
        for e, c in expr.args[:-1]:
            num_par += 1
            result += f"({self._print(c)} ? {self._print(e)} : "
        last_line = self._print(expr.args[-1].expr) + ")" * num_par
        return result + last_line

    def _print_Function(self, expr):
        if isinstance(expr, _AppliedUndef):
            return expr.func.__name__
        return f"{expr.func.__name__.lower()}({self.stringify(expr.args, ', ')})"

    def _print_Pow(self, expr):
        PREC = _precedence(expr)
        if expr.exp.is_integer and int(expr.exp) == 1:
            return self.parenthesize(expr.base, PREC)
        if expr.exp is sp.S.NegativeOne:
            return f"1.0/{self.parenthesize(expr.base, PREC)}"
        if expr.exp == 0.5:
            return f"sqrt({self._print(expr.base)})"
        # FIXME: Fix parentheses
        return "{0}^{1}".format(
            self.parenthesize(expr.base, PREC),
            self.parenthesize(expr.exp, PREC),
        )

    def _print_And(self, expr):
        PREC = _precedence(expr)
        return " && ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_Not(self, expr):
        PREC = _precedence(expr)
        return "!" + self.parenthesize(expr.args[0], PREC)

    def _print_Relational(self, expr):
        # sympy's rel_op ("==", "!=", "<", "<=", ">", ">=") is already valid
        # Julia syntax, so no MATLAB-oriented translation is applied here.
        return "{0} {1} {2}".format(
            self.parenthesize(expr.lhs, _precedence(expr)),
            expr.rel_op,
            self.parenthesize(expr.rhs, _precedence(expr)),
        )

    def _print_Or(self, expr):
        PREC = _precedence(expr)
        return " || ".join(self.parenthesize(arg, PREC) for arg in expr.args[::-1])

    def _print_re(self, expr):
        assert len(expr.args) == 1
        return f"real({self._print(expr.args[0])})"

    def _print_im(self, expr):
        assert len(expr.args) == 1
        return f"imag({self._print(expr.args[0])})"

    _print_Mul = _print_Mul


class _CustomLatexPrinter(_LatexPrinter):
    """
    LaTeX printer used by latex(); formats numbers with a mantissa and a
    power-of-ten exponent and prints undefined functions as symbols.
    """

    @staticmethod
    def _number_to_latex(value):
        if value < 0:
            sign = "-"
            value = -value
        else:
            sign = ""

        if abs(value) < 1e-32:
            rest, exponent = 0.0, 0
        else:
            rest, exponent = _get_potence(value)

        # If formatting 0.322 (mantissa with three integer digits at
        # exponent -3) print it as a plain decimal. Floor division restores
        # the intended ">= 100" test: with true division the test held for
        # any rest >= 1, so e.g. rest=1.5 became 0.0015 and printed "0.00".
        if exponent == -3 and int(rest) // 100 > 0:
            exponent = 0
            rest = float(rest) / 1000

        # Format rest
        if rest >= 100:
            form = "%d"
            rest = int(rest)
        elif rest >= 10:
            if rest % 1 > 0:
                form = "%.1f"
            else:
                form = "%d"
                rest = int(rest)
        else:
            if rest % 1 > 0:
                if (rest * 10) % 1 > 0:
                    form = "%.2f"
                else:
                    form = "%.1f"
            else:
                form = "%d"
                rest = int(rest)

        if exponent == 0:
            return sign + form % rest
        return r"%s\!\times\!10 ^{%d}" % (sign + form % rest, exponent)

    def _needs_brackets(self, expr):
        """
        Returns True if the expression needs to be wrapped in brackets when
        printed, False otherwise. For example: a + b => True; a => False;
        10 => False; -10 => True.
        """
        return not (
            (expr.is_Integer and expr.is_nonnegative)
            or (expr.is_Atom and expr is not sp.S.NegativeOne)
            or (isinstance(expr, _AppliedUndef) and expr is not sp.S.NegativeOne)
        )

    def _print_Integer(self, expr):
        return self._print_Float(expr.evalf())

    def _print_Float(self, expr):
        # If not finite we use parent printer
        if expr.is_zero:
            return "0"
        if not expr.is_finite:
            return _LatexPrinter._print_Float(self, expr)
        return self._number_to_latex(expr.evalf())

    def _print_Function(self, expr, *args, **kwargs):
        # Undefined functions such as f(t) are shown as the symbol "f"
        if isinstance(expr, _AppliedUndef):
            return self._print_Symbol(sp.Symbol(expr.func.__name__))
        return _LatexPrinter._print_Function(self, expr, *args, **kwargs)

    def _print_Add(self, expr):
        terms = list(expr.args)
        tex = self._print(terms[0])

        for term in terms[1:]:
            out = self._print(term)
            if out and out[0] != "-":
                tex += " +"
            tex += " " + out

        return tex

    def _print_Mul(self, expr):
        if self.order not in ("old", "none"):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = sp.Mul.make_args(expr)

        args = tuple(args)

        if _coeff_isneg(expr):
            # If negative and -1 is the first arg: remove it; otherwise negate
            # the leading coefficient. The sign is printed separately.
            if args[0] is sp.S.NegativeOne:
                args = args[1:]
            else:
                args = (-args[0],) + args[1:]
            tex = "- "
        else:
            tex = ""

        expr = sp.Mul(*args)

        from .sympy.simplify import fraction

        numer, denom = fraction(expr, exact=True)
        separator = self._settings["mul_symbol_latex"]
        numbersep = self._settings["mul_symbol_latex_numbers"]

        def convert(expr):
            # if expr is 1/1
            if (
                expr.is_Pow
                and expr.exp.is_Rational
                and expr.exp.is_negative
                and expr.base is sp.S.One
            ):
                expr = sp.S.One
            if not expr.is_Mul:
                return str(self._print(expr))
            else:
                _tex = last_term_tex = ""

                if self.order not in ("old", "none"):
                    args = expr.as_ordered_factors()
                else:
                    args = expr.args

                for i, term in enumerate(args):
                    term_tex = self._print(term)

                    if self._needs_mul_brackets(term, last=(i == len(args) - 1)):
                        term_tex = r"\left(%s\right)" % term_tex

                    if re.search("[0-9][} ]*$", last_term_tex) and re.match(
                        "[{ ]*[-+0-9]",
                        term_tex,
                    ):
                        # between two numbers
                        _tex += numbersep
                    elif _tex:
                        _tex += separator

                    _tex += term_tex
                    last_term_tex = term_tex
                return _tex

        if denom is sp.S.One:
            tex += convert(numer)
        else:
            snumer = convert(numer)
            sdenom = convert(denom)
            ldenom = len(sdenom.split())
            ratio = self._settings["long_frac_ratio"]
            if self._settings["fold_short_frac"] and ldenom <= 2 and "^" not in sdenom:
                # handle short fractions
                if self._needs_mul_brackets(numer, last=False):
                    tex += r"\left(%s\right) / %s" % (snumer, sdenom)
                else:
                    tex += r"%s / %s" % (snumer, sdenom)
            elif len(snumer.split()) > ratio * ldenom:
                # handle long fractions
                if self._needs_mul_brackets(numer, last=True):
                    tex += r"\frac{1}{%s}%s\left(%s\right)" % (
                        sdenom,
                        separator,
                        snumer,
                    )
                elif numer.is_Mul:
                    # split a long numerator
                    a = sp.S.One
                    b = sp.S.One
                    for x in numer.args:
                        if (
                            self._needs_mul_brackets(x, last=False)
                            or len(convert(a * x).split()) > ratio * ldenom
                            or (b.is_commutative is x.is_commutative is False)
                        ):
                            b *= x
                        else:
                            a *= x
                    if self._needs_mul_brackets(b, last=True):
                        tex += r"\frac{%s}{%s}%s\left(%s\right)" % (
                            convert(a),
                            sdenom,
                            separator,
                            convert(b),
                        )
                    else:
                        tex += r"\frac{%s}{%s}%s%s" % (
                            convert(a),
                            sdenom,
                            separator,
                            convert(b),
                        )
                else:
                    tex += r"\frac{1}{%s}%s%s" % (sdenom, separator, snumer)
            else:
                tex += r"\frac{%s}{%s}" % (snumer, sdenom)

        return tex

    def _print_Pow(self, expr):
        # Treat x**Rational(1,n) as special case
        if expr.exp.is_Rational and abs(expr.exp.p) == 1 and expr.exp.q != 1:
            base = self._print(expr.base)
            expq = expr.exp.q

            if expq == 2:
                tex = r"\sqrt{%s}" % base
            elif self._settings["itex"]:
                tex = r"\root{%d}{%s}" % (expq, base)
            else:
                tex = r"\sqrt[%d]{%s}" % (expq, base)

            if expr.exp.is_negative:
                return r"\frac{1}{%s}" % tex
            else:
                return tex
        elif (
            self._settings["fold_frac_powers"]
            and expr.exp.is_Rational
            and expr.exp.q != 1
        ):
            base, p, q = self._print(expr.base), expr.exp.p, expr.exp.q
            if expr.base.is_Function:
                return self._print(expr.base, f"{p}/{q}")
            if self._needs_brackets(expr.base):
                return r"\left(%s\right)^{%s/%s}" % (base, p, q)
            return r"%s^{%s/%s}" % (base, p, q)
        elif expr.exp.is_Rational and expr.exp.is_negative and expr.base.is_commutative:
            # Things like 1/x
            return self._print_Mul(expr)
        else:
            if expr.base.is_Function and not isinstance(expr.base, _AppliedUndef):
                return self._print(expr.base, self._print(expr.exp))
            else:
                if expr.is_commutative and expr.exp == -1:
                    # solves issue 1030
                    # As Mul always simplify 1/x to x**-1
                    # The objective is achieved with this hack
                    # first we get the latex for -1 * expr,
                    # which is a Mul expression
                    tex = self._print(sp.S.NegativeOne * expr).strip()
                    # the result comes with a minus and a space, so we remove
                    if tex[:1] == "-":
                        return tex[1:].strip()
                if self._needs_brackets(expr.base):
                    tex = r"\left(%s\right)^{%s}"
                else:
                    tex = r"%s^{%s}"

                return tex % (self._print(expr.base), self._print(expr.exp))


# Different math namespace python printer
_python_code_printer = {
    "": _CustomPythonCodePrinter(""),
    "np": _CustomPythonCodePrinter("np"),
    "numpy": _CustomPythonCodePrinter("numpy"),
    "math": _CustomPythonCodePrinter("math"),
    "ufl": _CustomPythonCodePrinter("ufl"),
}

# NOTE: the printers below are created with order=_order ("none"), which
# presumably keeps sympy from re-sorting Add/Mul arguments (compare the
# ``self.order not in ("old", "none")`` checks above) -- TODO confirm.
# Shared printer instances used by the *code() helpers below. The C/C++
# printers cover every (precision, language) combination, created in the
# order: C double, C++ double, C single, C++ single.
_ccode_printer, _cppcode_printer, _ccode_float_printer, _cppcode_float_printer = (
    _CustomCCodePrinter(cpp=use_cpp, float_precision=precision, order=_order)
    for precision in ("double", "single")
    for use_cpp in (False, True)
)
_sympy_printer = _CustomPythonPrinter()
_matlab_printer = _CustomMatlabCodePrinter(order=_order)
_julia_printer = _CustomJuliaCodePrinter(order=_order)
def ccode(expr, assign_to=None, float_precision="double"):
    """
    Return a C-code representation of a sympy expression

    Any ``float_precision`` other than "double" selects the single-precision
    printer. When ``assign_to`` is given, the result is the statement
    ``"<assign_to> = <code>"``; a target named "I" is written as "I_".
    """
    printer = _ccode_printer if float_precision == "double" else _ccode_float_printer
    code = printer.doprint(expr)
    if assign_to is None:
        return code
    # The C printer renders the symbol I as I_; rename the target to match.
    target = "I_" if assign_to == "I" else assign_to
    return f"{target} = {code}"
def cppcode(expr, assign_to=None, float_precision="double"):
    """
    Return a C++-code representation of a sympy expression

    Any ``float_precision`` other than "double" selects the single-precision
    printer. When ``assign_to`` is given, the result is the statement
    ``"<assign_to> = <code>"``; a target named "I" is written as "I_".
    """
    if float_precision == "double":
        ret = _cppcode_printer.doprint(expr)
    else:
        ret = _cppcode_float_printer.doprint(expr)
    if assign_to is None:
        return ret
    # The C++ printer renders the symbol I as I_ on the right-hand side, so
    # rename the assignment target the same way (as ccode does) to keep the
    # generated code consistent.
    if assign_to == "I":
        assign_to = "I_"
    return f"{assign_to} = {ret}"
def pythoncode(expr, assign_to=None, namespace="math"):
    """
    Return a Python-code representation of a sympy expression

    ``namespace`` picks the math prefix ("", "math", "np", "numpy" or
    "ufl"); an unknown namespace raises KeyError. When ``assign_to`` is
    given, the result is the statement ``"<assign_to> = <code>"``.
    """
    printer = _python_code_printer[namespace]
    code = printer.doprint(expr)
    return code if assign_to is None else f"{assign_to} = {code}"
def sympycode(expr, assign_to=None):
    """
    Return a sympy-syntax string for ``expr``; with ``assign_to`` the result
    is the statement ``"<assign_to> = <code>"``.
    """
    code = _sympy_printer.doprint(expr)
    return code if assign_to is None else f"{assign_to} = {code}"
def matlabcode(expr, assign_to=None):
    """
    Return a MATLAB-code string for ``expr``; with ``assign_to`` the result
    is the statement ``"<assign_to> = <code>"``.
    """
    code = _matlab_printer.doprint(expr)
    return code if assign_to is None else f"{assign_to} = {code}"
def juliacode(expr, assign_to=None):
    """
    Return a Julia-code string for ``expr``; with ``assign_to`` the result
    is the statement ``"<assign_to> = <code>"``.
    """
    code = _julia_printer.doprint(expr)
    return code if assign_to is None else f"{assign_to} = {code}"
def latex(expr, **settings):
    # NOTE: this function's __doc__ is replaced with sympy's latex() docstring
    # right after the definition, so it is documented with comments instead.
    #
    # Strings that name an attribute of the sp namespace are returned as-is;
    # other strings and plain scalars are sympified first. Printing always
    # uses order="none", overriding any caller-supplied order.
    is_text = isinstance(expr, str)
    if is_text and expr in sp.__dict__:
        return expr
    if is_text or isinstance(expr, _scalars):
        expr = sp.sympify(expr)
    return _CustomLatexPrinter(dict(settings, order="none")).doprint(expr)
# Reuse sympy's latex() documentation for the wrapper defined above.
latex.__doc__ = _sympy_latex.__doc__

# Octave shares the MATLAB printer.
octavecode = matlabcode

# Export every module-level name that does not start with an underscore.
__all__ = [_name for _name in list(globals()) if not _name.startswith("_")]