Source code for brian2.parsing.bast

"""
Brian AST representation

This is a standard Python AST representation with additional information added.
"""

import ast
import weakref

import numpy

from brian2.utils.logger import get_logger

__all__ = ["brian_ast", "BrianASTRenderer", "dtype_hierarchy"]


logger = get_logger(__name__)

# This codifies the idea that operations involving e.g. boolean and integer will end up
# as integer. In general the output type will be the max of the hierarchy values here.
dtype_hierarchy = {
    "boolean": 0,
    "integer": 1,
    "float": 2,
}
# This is just so you can invert from number to string
for tc, i in dict(dtype_hierarchy).items():
    dtype_hierarchy[i] = tc


[docs] def is_boolean(value): return isinstance(value, bool)
[docs] def is_integer(value): return isinstance(value, (int, numpy.integer))
[docs] def is_float(value): return isinstance(value, (float, numpy.float32, numpy.float64))
[docs] def brian_dtype_from_value(value): """ Returns 'boolean', 'integer' or 'float' """ if is_float(value): return "float" elif is_integer(value): return "integer" elif is_boolean(value): return "boolean" raise TypeError(f"Unknown dtype for value {str(value)}")
# The following functions are called very often during the optimisation process # so we don't use numpy.issubdtype but instead a precalculated list of all # standard types bool_dtype = numpy.dtype(bool)
[docs] def is_boolean_dtype(obj): return numpy.dtype(obj) is bool_dtype
integer_dtypes = {numpy.dtype(c) for c in numpy.typecodes["AllInteger"]}
[docs] def is_integer_dtype(obj): return numpy.dtype(obj) in integer_dtypes
float_dtypes = {numpy.dtype(c) for c in numpy.typecodes["AllFloat"]}
[docs] def is_float_dtype(obj): return numpy.dtype(obj) in float_dtypes
[docs] def brian_dtype_from_dtype(dtype): """ Returns 'boolean', 'integer' or 'float' """ if is_float_dtype(dtype): return "float" elif is_integer_dtype(dtype): return "integer" elif is_boolean_dtype(dtype): return "boolean" raise TypeError(f"Unknown dtype: {str(dtype)}")
[docs] def brian_ast(expr, variables): """ Returns an AST tree representation with additional information Each node will be a standard Python ``ast`` node with the following additional attributes: ``dtype`` One of ``'boolean'``, ``'integer'`` or ``'float'``, referring to the data type of the value of this node. ``scalar`` Either ``True`` or ``False`` if the node uses any vector-valued variables. ``complexity`` An integer representation of the computational complexity of the node. This is a very rough representation used for things like ``2*(x+y)`` is less complex than ``2*x+2*y`` and ``exp(x)`` is more complex than ``2*x`` but shouldn't be relied on for fine distinctions between expressions. Parameters ---------- expr : str The expression to convert into an AST representation variables : dict The dictionary of `Variable` objects used in the expression. """ node = ast.parse(expr, mode="eval").body renderer = BrianASTRenderer(variables) return renderer.render_node(node)
[docs] class BrianASTRenderer: """ This class is modelled after `NodeRenderer` - see there for details. """ def __init__(self, variables, copy_variables=True): if copy_variables: self.variables = variables.copy() else: self.variables = variables
[docs] def render_node(self, node): nodename = node.__class__.__name__ methname = f"render_{nodename}" try: return getattr(self, methname)(node) except AttributeError: raise SyntaxError(f"Unknown syntax: {nodename}")
[docs] def render_NameConstant(self, node): if node.value is not True and node.value is not False: raise SyntaxError(f"Unknown NameConstant {str(node.value)}") # NameConstant only used for True and False and None, and we don't support None node.dtype = "boolean" node.scalar = True node.complexity = 0 node.stateless = True return node
[docs] def render_Name(self, node): node.complexity = 0 if node.id == "True" or node.id == "False": node.dtype = "boolean" node.scalar = True elif node.id in self.variables: var = self.variables[node.id] dtype = var.dtype node.dtype = brian_dtype_from_dtype(dtype) node.scalar = var.scalar else: # don't think we need to handle other names (pi, e, inf)? node.dtype = "float" node.scalar = True # I think this assumption is OK, but not certain node.stateless = True return node
[docs] def render_Num(self, node): node.complexity = 0 node.dtype = brian_dtype_from_value(node.value) node.scalar = True node.stateless = True return node
[docs] def render_Constant(self, node): # For literals in Python >= 3.8 if node.value is True or node.value is False or node.value is None: return self.render_NameConstant(node) else: return self.render_Num(node)
[docs] def render_Call(self, node): if len(node.keywords): raise ValueError("Keyword arguments not supported.") elif getattr(node, "starargs", None) is not None: raise ValueError("Variable number of arguments not supported") elif getattr(node, "kwargs", None) is not None: raise ValueError("Keyword arguments not supported") args = [] for subnode in node.args: subnode.parent = weakref.proxy(node) subnode = self.render_node(subnode) args.append(subnode) node.args = args node.dtype = "float" # default dtype # Condition for scalarity of function call: stateless and arguments are scalar node.scalar = False if node.func.id in self.variables: funcvar = self.variables[node.func.id] # sometimes this attribute doesn't exist, if so assume it's not stateless node.stateless = getattr(funcvar, "stateless", False) node.auto_vectorise = getattr(funcvar, "auto_vectorise", False) if node.stateless and not node.auto_vectorise: node.scalar = all(subnode.scalar for subnode in node.args) # check that argument types are valid node_arg_types = [subnode.dtype for subnode in node.args] for subnode, argtype in zip(node.args, funcvar._arg_types): if argtype != "any" and argtype != subnode.dtype: raise TypeError( f"Function '{node.func.id}' takes arguments with " f"types {funcvar._arg_types} but " f"received {node_arg_types}." ) # compute return type return_type = funcvar._return_type if return_type == "highest": return_type = dtype_hierarchy[ max(dtype_hierarchy[nat] for nat in node_arg_types) ] node.dtype = return_type else: node.stateless = False # we leave node.func because it is an ast.Name object that doesn't have a dtype # TODO: variable complexity for function calls? node.complexity = 20 + sum(subnode.complexity for subnode in node.args) return node
[docs] def render_BinOp(self, node): node.left.parent = weakref.proxy(node) node.right.parent = weakref.proxy(node) node.left = self.render_node(node.left) node.right = self.render_node(node.right) # TODO: we could capture some syntax errors here, e.g. bool+bool # captures, e.g. int+float->float newdtype = dtype_hierarchy[ max(dtype_hierarchy[subnode.dtype] for subnode in [node.left, node.right]) ] if node.op.__class__.__name__ == "Div": # Division turns integers into floating point values newdtype = "float" node.dtype = newdtype node.scalar = node.left.scalar and node.right.scalar node.complexity = 1 + node.left.complexity + node.right.complexity node.stateless = node.left.stateless and node.right.stateless return node
[docs] def render_BoolOp(self, node): values = [] for subnode in node.values: subnode.parent = node subnode = self.render_node(subnode) values.append(subnode) node.values = values node.dtype = "boolean" for subnode in node.values: if subnode.dtype != "boolean": raise TypeError("Boolean operator acting on non-booleans") node.scalar = all(subnode.scalar for subnode in node.values) node.complexity = 1 + sum(subnode.complexity for subnode in node.values) node.stateless = all(subnode.stateless for subnode in node.values) return node
[docs] def render_Compare(self, node): node.left = self.render_node(node.left) comparators = [] for subnode in node.comparators: subnode.parent = node subnode = self.render_node(subnode) comparators.append(subnode) node.comparators = comparators node.dtype = "boolean" comparators = [node.left] + node.comparators node.scalar = all(subnode.scalar for subnode in comparators) node.complexity = 1 + sum(subnode.complexity for subnode in comparators) node.stateless = node.left.stateless and all( c.stateless for c in node.comparators ) return node
[docs] def render_UnaryOp(self, node): node.operand.parent = node node.operand = self.render_node(node.operand) node.dtype = node.operand.dtype if node.dtype == "boolean" and node.op.__class__.__name__ != "Not": raise TypeError( f"Unary operator {node.op.__class__.__name__} does not apply to boolean" " types" ) node.scalar = node.operand.scalar node.complexity = 1 + node.operand.complexity node.stateless = node.operand.stateless return node