"""
Utility functions for parsing expressions and statements.
"""
import re
from collections import Counter
import sympy
from sympy.printing.precedence import precedence
from sympy.printing.str import StrPrinter
from brian2.core.functions import (DEFAULT_FUNCTIONS, DEFAULT_CONSTANTS,
Function)
from brian2.parsing.rendering import SympyNodeRenderer
from brian2.utils.caching import cached
def check_expression_for_multiple_stateful_functions(expr, variables):
    """
    Reject expressions that use the same stateful function more than once.

    Every word-like token (regex ``\\w+``) in ``expr`` is counted. If a token
    occurs more than once and ``variables`` maps it to a `Function` whose
    ``stateless`` attribute is false, the expression is rejected. Note that
    this counts textual occurrences of the name, not only calls.

    Parameters
    ----------
    expr : str
        The expression to check.
    variables : dict
        Mapping from identifier names to their `Variable`/`Function` objects.
        Identifiers missing from this mapping are ignored.

    Raises
    ------
    NotImplementedError
        If a stateful function occurs more than once in ``expr``.
    """
    identifier_count = Counter(re.findall(r'\w+', expr))
    for identifier, count in identifier_count.items():
        # Only repeated identifiers can be a problem; skip the lookup otherwise
        if count < 2:
            continue
        var = variables.get(identifier, None)
        if isinstance(var, Function) and not var.stateless:
            raise NotImplementedError(
                f"The expression '{expr}' contains more than one call of "
                f"{identifier}. This is currently not supported since "
                f"{identifier} is a stateful function and its multiple calls "
                f"might be treated incorrectly (e.g. 'rand() - rand()' could "
                f"be simplified to '0.0')."
            )
def str_to_sympy(expr, variables=None):
    """
    Parse a string into a sympy expression.

    `sympify` is deliberately not used directly, for two reasons:

    1) sympify does a ``from sympy import *``, which puts every sympy
       function into its namespace. Sympy function names then cannot be used
       as variable names; for example ``beta`` and ``factor`` are reasonable
       variable names but are sympy functions, so using them as variables
       would lead to a parsing error.
    2) Expressions and statements should share one syntax, e.g. allowing
       `and` (instead of `&`) and function names like `ceil` (instead of
       `ceiling`).

    Parameters
    ----------
    expr : str
        The string expression to parse.
    variables : dict, optional
        Dictionary mapping variable/function names in the expr to their
        respective `Variable`/`Function` objects.

    Returns
    -------
    s_expr
        A sympy expression

    Raises
    ------
    SyntaxError
        In case of any problems during parsing.
    """
    known_variables = {} if variables is None else variables
    check_expression_for_multiple_stateful_functions(expr, known_variables)
    # The translation itself lives in a separate cached function. Caching
    # this function instead would make the `variables` dictionary part of the
    # cache key, although it is only needed for the check above and has no
    # influence on the sympy translation.
    return _str_to_sympy(expr)
@cached
def _str_to_sympy(expr):
    """
    Translate the string ``expr`` into a sympy expression.

    This is the cached worker behind `str_to_sympy`; it depends only on
    ``expr``.

    Parameters
    ----------
    expr : str
        The string expression to translate.

    Returns
    -------
    s_expr
        The expression produced by ``SympyNodeRenderer().render_expr``.

    Raises
    ------
    SyntaxError
        If rendering raises a `TypeError`, `ValueError` or `NameError`. The
        original exception is chained as ``__cause__``.
    """
    try:
        s_expr = SympyNodeRenderer().render_expr(expr)
    except (TypeError, ValueError, NameError) as ex:
        # Chain the original error so its traceback is not lost
        raise SyntaxError(f"Error during evaluation of sympy expression "
                          f"'{expr}': {ex}") from ex
    return s_expr
class CustomSympyPrinter(StrPrinter):
    """
    String printer that renders a few basic sympy objects in Brian's
    expression syntax, e.g. ``(a) and (b)`` instead of ``And(a, b)``.
    """

    def _join_parenthesized_args(self, expr, separator):
        # Print each argument of `expr` wrapped in parentheses and join the
        # pieces with `separator` (shared by the And/Or printers below).
        return separator.join(f'({self.doprint(arg)})' for arg in expr.args)

    def _print_And(self, expr):
        return self._join_parenthesized_args(expr, ' and ')

    def _print_Or(self, expr):
        return self._join_parenthesized_args(expr, ' or ')

    def _print_Not(self, expr):
        # A negation must wrap exactly one operand
        n_args = len(expr.args)
        if n_args != 1:
            raise AssertionError(f'"Not" with {n_args} arguments?')
        (operand,) = expr.args
        return f'not ({self.doprint(operand)})'

    def _print_Relational(self, expr):
        prec = precedence(expr)
        lhs = self.parenthesize(expr.lhs, prec)
        # Use the printer's `_relationals` entry for this operator if there
        # is one, otherwise print `rel_op` unchanged
        op_str = self._relationals.get(expr.rel_op) or expr.rel_op
        rhs = self.parenthesize(expr.rhs, prec)
        return f"{lhs} {op_str} {rhs}"

    def _print_Function(self, expr):
        func_name = expr.func.__name__
        if func_name == 'Mod':
            # Print the modulo as an explicit, fully parenthesized `%`
            dividend = self.doprint(expr.args[0])
            divisor = self.doprint(expr.args[1])
            return f'(({dividend})%({divisor}))'
        # Special workaround: the `int_` function is printed as `int`
        if func_name == 'int_':
            func_name = 'int'
        return f"{func_name}({self.stringify(expr.args, ', ')})"
# Module-level printer instance, used by `sympy_to_str` to turn sympy
# expressions into strings.
PRINTER = CustomSympyPrinter()
@cached
def sympy_to_str(sympy_expr):
    """
    sympy_to_str(sympy_expr)

    Convert a sympy expression into a string.

    ``str(sympy_expr)`` is not enough: the expression may contain sympy
    functions and constants whose names differ from ours. For example,
    ``Abs`` can appear if an expression such as ``sqrt(x**2)`` was
    simplified, and it should be printed as ``abs``. Such objects are
    substituted by identically named placeholders before printing.

    Parameters
    ----------
    sympy_expr : sympy.core.expr.Expr
        The expression that should be converted to a string.

    Returns
    -------
    str_expr : str
        A string representing the sympy expression.
    """
    replacements = {}
    # Default functions whose sympy name differs from our name
    for name, fn in DEFAULT_FUNCTIONS.items():
        sympy_func = fn.sympy_func
        if (sympy_func is not None
                and isinstance(sympy_func, sympy.FunctionClass)
                and str(sympy_func) != name):
            replacements[sympy_func] = sympy.Function(name)
    # Default constants whose sympy name differs from our name
    for name, const in DEFAULT_CONSTANTS.items():
        if str(const.sympy_obj) != name:
            replacements[const.sympy_obj] = sympy.Symbol(name)
    # The placeholder argument is printed as an empty symbol
    replacements[sympy.Symbol('_placeholder_arg')] = sympy.Symbol('')

    # Collect the atoms plus the function classes that are applied in the
    # expression, so that only substitutions that can match are performed
    present = sympy_expr.atoms()
    present |= {applied.func for applied in sympy_expr.atoms(sympy.Function)}
    for old, new in replacements.items():
        if old in present:
            sympy_expr = sympy_expr.subs(old, new)
    return PRINTER.doprint(sympy_expr)
def expression_complexity(expr, complexity=None):
    """
    Returns the complexity of an expression (either string or sympy).

    The expression's operations are counted with `sympy.count_ops` and
    weighted by cost: 1 for ``ADD``, ``MUL`` and ``SUB``, 2 for ``DIV``, and
    20 for any other operation. Comparisons in string expressions are counted
    as additions. The costs can be overridden with the ``complexity``
    argument.

    Note: calling this on a statement rather than an expression is likely to
    lead to errors.

    Parameters
    ----------
    expr : `sympy.Expr` or str
        The expression.
    complexity : None or dict (optional)
        A dictionary mapping operation names, as used by
        ``sympy.count_ops(..., visual=True)`` (e.g. ``'ADD'``, ``'DIV'``), to
        their cost. Entries override the default costs.

    Returns
    -------
    complexity
        The complexity of the expression, i.e. the weighted operation count
        as returned by sympy's ``evalf``.
    """
    if isinstance(expr, str):
        # sympy.count_ops doesn't handle (in)equalities, so count comparisons
        # as additions (TODO: handle sympy as well as str). Two-character
        # operators must be replaced before '<' and '>', otherwise e.g. '<='
        # would turn into '+='.
        for op in ['!=', '<=', '>=', '==', '<', '>']:
            expr = expr.replace(op, '+')
        # work around bug with rand() and randn() (TODO: improve this)
        expr = expr.replace('rand()', 'rand(0)')
        expr = expr.replace('randn()', 'randn(0)')
    subs = {'ADD': 1, 'DIV': 2, 'MUL': 1, 'SUB': 1}
    if complexity is not None:
        subs.update(complexity)
    ops = sympy.count_ops(expr, visual=True)
    for atom in ops.atoms():
        # Only operations without a known cost get the large default cost;
        # the defaults above and user-supplied costs must not be overwritten
        if hasattr(atom, 'name') and atom.name not in subs:
            subs[atom.name] = 20
    return ops.evalf(subs=subs)