"""
Utility functions for parsing expressions and statements.
"""
import re
from collections import Counter
import sympy
from sympy.printing.precedence import precedence
from sympy.printing.str import StrPrinter
from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS, Function
from brian2.parsing.rendering import SympyNodeRenderer
from brian2.utils.caching import cached
def check_expression_for_multiple_stateful_functions(expr, variables):
    """
    Reject expressions that use the same stateful function more than once.

    Sympy may treat repeated occurrences of the same function call as
    identical terms (e.g. ``rand() - rand()`` could be simplified to
    ``0.0``), which is wrong for stateful functions. Every word-like token of
    ``expr`` is counted, so any repeated occurrence of a stateful function's
    name triggers the error.

    Parameters
    ----------
    expr : str
        The expression to check.
    variables : dict
        Dictionary mapping names to their respective `Variable`/`Function`
        objects. Only `Function` objects whose ``stateless`` attribute is
        false are considered.

    Raises
    ------
    NotImplementedError
        If the name of a stateful function occurs more than once in ``expr``.
    """
    identifiers = re.findall(r"\w+", expr)
    # Don't bother counting if we don't have any duplicates in the first place
    if len(identifiers) == len(set(identifiers)):
        return
    for identifier, count in Counter(identifiers).items():
        # Only names that occur several times can be problematic
        if count < 2:
            continue
        var = variables.get(identifier, None)
        if isinstance(var, Function) and not var.stateless:
            raise NotImplementedError(
                f"The expression '{expr}' contains "
                f"more than one call of {identifier}. This "
                "is currently not supported since "
                f"{identifier} is a stateful function and "
                "its multiple calls might be "
                "treated incorrectly (e.g. "
                "'rand() - rand()' could be "
                "simplified to "
                "'0.0')."
            )
def str_to_sympy(expr, variables=None):
    """
    Parse a string into a sympy expression.

    `sympify` is deliberately not used directly, for two reasons:

    1. ``sympify`` does a ``from sympy import *``, adding all sympy functions
       to its namespace. Sympy function names then cannot be used as variable
       names: ``beta`` and ``factor``, for example, are quite reasonable
       variable names but are also sympy functions, so using them would lead
       to a parsing error.
    2. Expressions and statements should share a common syntax, e.g. allowing
       ``and`` (instead of ``&``) and function names like ``ceil`` (instead
       of ``ceiling``).

    Parameters
    ----------
    expr : str
        The string expression to parse.
    variables : dict, optional
        Dictionary mapping variable/function names in the expr to their
        respective `Variable`/`Function` objects. It is only used to check
        that stateful functions are not called more than once.

    Returns
    -------
    s_expr
        A sympy expression

    Raises
    ------
    SyntaxError
        In case of any problems during parsing.
    NotImplementedError
        If a stateful function is used more than once in ``expr`` (see
        `check_expression_for_multiple_stateful_functions`).
    """
    known_names = {} if variables is None else variables
    check_expression_for_multiple_stateful_functions(expr, known_names)
    # The translation is delegated to the cached `_str_to_sympy`, whose cache
    # key is `expr` alone. `variables` only matters for the check above and
    # does not affect the resulting sympy expression, so it must not become
    # part of the cache key.
    return _str_to_sympy(expr)
@cached
def _str_to_sympy(expr):
    """
    Translate a string expression into a sympy expression (cached on ``expr``).

    Parameters
    ----------
    expr : str
        The string expression to parse.

    Returns
    -------
    s_expr
        The sympy expression returned by `SympyNodeRenderer.render_expr`.

    Raises
    ------
    SyntaxError
        If rendering fails with a `TypeError`, `ValueError` or `NameError`.
        The original exception is attached as ``__cause__``.
    """
    try:
        s_expr = SympyNodeRenderer().render_expr(expr)
    except (TypeError, ValueError, NameError) as ex:
        # Chain explicitly so that the underlying rendering error is reported
        # as the direct cause of the SyntaxError
        raise SyntaxError(
            f"Error during evaluation of sympy expression '{expr}': {ex}"
        ) from ex
    return s_expr
class CustomSympyPrinter(StrPrinter):
    """
    String printer that overrides the printing of some basic sympy objects.

    For example, ``And(a, b)`` is printed as ``(a) and (b)``, ``Not(a)`` as
    ``not (a)`` and ``Mod(a, b)`` as ``((a)%(b))``.
    """

    def _join_parenthesized(self, args, separator):
        # Print each argument wrapped in parentheses, joined by `separator`
        return separator.join([f"({self.doprint(arg)})" for arg in args])

    def _print_And(self, expr):
        return self._join_parenthesized(expr.args, " and ")

    def _print_Or(self, expr):
        return self._join_parenthesized(expr.args, " or ")

    def _print_Not(self, expr):
        n_args = len(expr.args)
        if n_args == 1:
            return f"not ({self.doprint(expr.args[0])})"
        raise AssertionError(f'"Not" with {n_args} arguments?')

    def _print_Relational(self, expr):
        prec = precedence(expr)
        left = self.parenthesize(expr.lhs, prec)
        operator = self._relationals.get(expr.rel_op) or expr.rel_op
        right = self.parenthesize(expr.rhs, prec)
        return f"{left} {operator} {right}"

    def _print_Function(self, expr):
        func_name = expr.func.__name__
        if func_name == "Mod":
            # Printed with the infix modulo operator, both operands in parens
            return f"(({self.doprint(expr.args[0])})%({self.doprint(expr.args[1])}))"
        # Special workaround for the int function: `int_` is printed as `int`
        printed_name = "int" if func_name == "int_" else func_name
        return f"{printed_name}({self.stringify(expr.args, ', ')})"
# Module-level printer instance, used by `sympy_to_str` to turn sympy
# expressions into strings
PRINTER: CustomSympyPrinter = CustomSympyPrinter()
@cached
def sympy_to_str(sympy_expr):
    """
    sympy_to_str(sympy_expr)

    Convert a sympy expression into a string.

    This could be as easy as ``str(sympy_expr)``, but the sympy expression
    may contain functions like ``Abs`` (for example, if an expression such as
    ``sqrt(x**2)`` appeared somewhere). Such functions (and constants) are
    re-translated into our own names, e.g. ``Abs`` into ``abs``.

    Parameters
    ----------
    sympy_expr : sympy.core.expr.Expr
        The expression that should be converted to a string.

    Returns
    -------
    str_expr : str
        A string representing the sympy expression.
    """
    # Default functions whose sympy function class prints under a different
    # name get replaced by a plain sympy function carrying our name
    function_replacements = {}
    for func_name, func in DEFAULT_FUNCTIONS.items():
        sympy_func = func.sympy_func
        if (
            sympy_func is not None
            and isinstance(sympy_func, sympy.FunctionClass)
            and str(sympy_func) != func_name
        ):
            function_replacements[sympy_func] = sympy.Function(func_name)

    # Likewise for default constants whose sympy object prints differently
    constant_replacements = {}
    for const_name, const in DEFAULT_CONSTANTS.items():
        if str(const.sympy_obj) != const_name:
            constant_replacements[const.sympy_obj] = sympy.Symbol(const_name)

    replacements = {
        **function_replacements,
        **constant_replacements,
        # The placeholder argument is replaced by an empty symbol
        sympy.Symbol("_placeholder_arg"): sympy.Symbol(""),
    }

    # Only substitute objects that actually occur in the expression; applied
    # functions contribute their function class as well
    function_classes = {applied.func for applied in sympy_expr.atoms(sympy.Function)}
    present = sympy_expr.atoms() | function_classes
    applicable = [(old, new) for old, new in replacements.items() if old in present]

    result = sympy_expr
    for old, new in applicable:
        result = result.subs(old, new)
    return PRINTER.doprint(result)
def expression_complexity(expr, complexity=None):
    """
    Returns the complexity of an expression (either string or sympy)

    The complexity is defined as 1 for each arithmetic operation except
    divide, which is 2, and 20 for all other operations. This can be
    overridden using the complexity argument.

    Note: calling this on a statement rather than an expression is likely to
    lead to errors.

    Parameters
    ----------
    expr : `sympy.Expr` or str
        The expression.
    complexity : None or dict (optional)
        A dictionary mapping operation names (as reported by
        ``sympy.count_ops(..., visual=True)``, e.g. ``"ADD"`` or ``"DIV"``)
        to their complexity, to overwrite the default behaviour.

    Returns
    -------
    complexity
        The complexity of the expression. This is the result of sympy's
        ``evalf``, i.e. a sympy number rather than a Python ``int``.
    """
    if isinstance(expr, str):
        # we do this because sympy.count_ops doesn't handle inequalities
        # (TODO: handle sympy as well str). Every comparison is counted like
        # an addition; two-character operators must be replaced before their
        # one-character prefixes ("<=" before "<", ">=" before ">").
        for op in ["<=", ">=", "==", "!=", "<", ">"]:
            expr = expr.replace(op, "+")
        # work around bug with rand() and randn() (TODO: improve this)
        expr = expr.replace("rand()", "rand(0)")
        expr = expr.replace("randn()", "randn(0)")
    subs = {"ADD": 1, "DIV": 2, "MUL": 1, "SUB": 1}
    if complexity is not None:
        subs.update(complexity)
    ops = sympy.count_ops(expr, visual=True)
    for atom in ops.atoms():
        if hasattr(atom, "name"):
            # unknown operations are assumed to have a large cost. Known
            # operations (the defaults above and any user-supplied costs)
            # must keep their configured value.
            subs.setdefault(atom.name, 20)
    return ops.evalf(subs=subs)