# Source code for brian2.parsing.bast

'''
Brian AST representation

This is a standard Python AST representation with additional information added.
'''

import ast
import numpy
from __builtin__ import all as logical_all # defensive programming against numpy import

__all__ = ['brian_ast', 'BrianASTRenderer', 'dtype_hierarchy']


# This codifies the idea that operations involving e.g. boolean and integer
# will end up as integer. In general the output type will be the max of the
# hierarchy values here.
dtype_hierarchy = {'boolean': 0,
                   'integer': 1,
                   'float': 2,
                   }
# Store the inverse mapping (rank -> name) in the same dictionary, so that
# ``dtype_hierarchy[max(ranks)]`` gives back a dtype name. The loop runs over a
# snapshot of the items: inserting keys while iterating over a live ``items()``
# view raises a RuntimeError on Python 3.
for tc, i in list(dtype_hierarchy.items()):
    dtype_hierarchy[i] = tc

def is_boolean(value):
    '''
    Return whether ``value`` is an instance of Python's ``bool``.
    '''
    return isinstance(value, bool)
def is_integer(value):
    '''
    Return whether ``value`` is a Python ``int`` or a numpy integer scalar.

    Note that ``bool`` is a subclass of ``int``, so this is also true for
    ``True`` and ``False``.
    '''
    return isinstance(value, (int, numpy.integer))
def is_float(value):
    '''
    Return whether ``value`` is a Python ``float`` or a numpy floating point
    scalar.
    '''
    # ``numpy.floating`` is the abstract base class of numpy's floating point
    # scalars. The previously used ``numpy.float`` was only an alias of the
    # builtin ``float`` and is not available in current numpy releases.
    return isinstance(value, (float, numpy.floating))
def brian_dtype_from_value(value):
    '''
    Returns 'boolean', 'integer' or 'float'

    Parameters
    ----------
    value : object
        The value to classify.

    Returns
    -------
    dtype : str
        ``'boolean'``, ``'integer'`` or ``'float'``.

    Raises
    ------
    TypeError
        If ``value`` is neither a float, an integer nor a boolean.
    '''
    if is_float(value):
        return 'float'
    # Booleans have to be checked before integers: ``bool`` is a subclass of
    # ``int``, so ``is_integer(True)`` holds as well and the boolean branch
    # could never be reached if it came second.
    if is_boolean(value):
        return 'boolean'
    if is_integer(value):
        return 'integer'
    raise TypeError("Unknown dtype for value "+str(value))
# The following functions are called very often during the optimisation process
# so we don't use numpy.issubdtype but instead a precalculated list of all
# standard types

# Built from the builtin ``bool``: the ``numpy.bool`` alias is not available in
# every numpy release
bool_dtype = numpy.dtype(bool)


def is_boolean_dtype(obj):
    '''
    Return whether ``obj`` describes the boolean dtype.

    Parameters
    ----------
    obj : object
        Anything that ``numpy.dtype`` can convert into a dtype.
    '''
    # identity comparison against the precomputed dtype object
    return numpy.dtype(obj) is bool_dtype
# Precomputed set of the dtypes for all of numpy's integer type codes
integer_dtypes = {numpy.dtype(c) for c in numpy.typecodes['AllInteger']}


def is_integer_dtype(obj):
    '''
    Return whether ``obj`` describes one of numpy's integer dtypes.

    Parameters
    ----------
    obj : object
        Anything that ``numpy.dtype`` can convert into a dtype.
    '''
    return numpy.dtype(obj) in integer_dtypes
# Precomputed set of the dtypes for all of numpy's 'AllFloat' type codes
# NOTE(review): numpy's 'AllFloat' type codes appear to include the complex
# types as well -- confirm that this is intended
float_dtypes = {numpy.dtype(c) for c in numpy.typecodes['AllFloat']}


def is_float_dtype(obj):
    '''
    Return whether ``obj`` describes one of the dtypes in `float_dtypes`.

    Parameters
    ----------
    obj : object
        Anything that ``numpy.dtype`` can convert into a dtype.
    '''
    return numpy.dtype(obj) in float_dtypes
def brian_dtype_from_dtype(dtype):
    '''
    Returns 'boolean', 'integer' or 'float'

    Parameters
    ----------
    dtype : object
        A numpy dtype, or anything that ``numpy.dtype`` can convert into one.

    Returns
    -------
    dtype_name : str
        ``'boolean'``, ``'integer'`` or ``'float'``.

    Raises
    ------
    TypeError
        If ``dtype`` is neither a float, an integer nor the boolean dtype.
    '''
    if is_float_dtype(dtype):
        return 'float'
    if is_integer_dtype(dtype):
        return 'integer'
    if is_boolean_dtype(dtype):
        return 'boolean'
    raise TypeError("Unknown dtype: "+str(dtype))
def brian_ast(expr, variables):
    '''
    Returns an AST tree representation with additional information

    Each node will be a standard Python ``ast`` node with the following
    additional attributes:

    ``dtype``
        One of ``'boolean'``, ``'integer'`` or ``'float'``, referring to the
        data type of the value of this node.
    ``scalar``
        ``True`` if the node does not use any vector-valued variables,
        ``False`` otherwise.
    ``complexity``
        An integer representation of the computational complexity of the
        node. This is a very rough representation used for things like
        ``2*(x+y)`` is less complex than ``2*x+2*y`` and ``exp(x)`` is more
        complex than ``2*x`` but shouldn't be relied on for fine distinctions
        between expressions.

    Parameters
    ----------
    expr : str
        The expression to convert into an AST representation
    variables : dict
        The dictionary of `Variable` objects used in the expression.

    Returns
    -------
    node : ast.AST
        The annotated top-level node of the parsed expression.
    '''
    # mode='eval' parses a single expression; its ``body`` is the expression
    # node itself
    tree = ast.parse(expr, mode='eval')
    return BrianASTRenderer(variables).render_node(tree.body)
[docs]class BrianASTRenderer(object): ''' This class is modelled after `NodeRenderer` - see there for details. ''' def __init__(self, variables, copy_variables=True): if copy_variables: self.variables = variables.copy() else: self.variables = variables
[docs] def render_node(self, node): nodename = node.__class__.__name__ methname = 'render_'+nodename try: return getattr(self, methname)(node) except AttributeError: raise SyntaxError("Unknown syntax: " + nodename)
[docs] def render_NameConstant(self, node): if node.value is not True and node.value is not False: raise SyntaxError("Unknown NameConstant "+str(node.value)) # NameConstant only used for True and False and None, and we don't support None node.dtype = 'boolean' node.scalar = True node.complexity = 0 node.stateless = True return node
[docs] def render_Name(self, node): node.complexity = 0 if node.id=='True' or node.id=='False': node.dtype = 'boolean' node.scalar = True elif node.id in self.variables: var = self.variables[node.id] dtype = var.dtype node.dtype = brian_dtype_from_dtype(dtype) node.scalar = var.scalar else: # don't think we need to handle other names (pi, e, inf)? node.dtype = 'float' node.scalar = True # I think this assumption is OK, but not certain node.stateless = True return node
[docs] def render_Num(self, node): node.complexity = 0 node.dtype = brian_dtype_from_value(node.n) node.scalar = True node.stateless = True return node
[docs] def render_Call(self, node): if len(node.keywords): raise ValueError("Keyword arguments not supported.") elif getattr(node, 'starargs', None) is not None: raise ValueError("Variable number of arguments not supported") elif getattr(node, 'kwargs', None) is not None: raise ValueError("Keyword arguments not supported") node.args = [self.render_node(subnode) for subnode in node.args] node.dtype = 'float' # default dtype # Condition for scalarity of function call: stateless and arguments are scalar node.scalar = False if node.func.id in self.variables: funcvar = self.variables[node.func.id] # sometimes this attribute doesn't exist, if so assume it's not stateless node.stateless = getattr(funcvar, 'stateless', False) if node.stateless: node.scalar = logical_all(subnode.scalar for subnode in node.args) # check that argument types are valid node_arg_types = [subnode.dtype for subnode in node.args] for subnode, argtype in zip(node.args, funcvar._arg_types): if argtype!='any' and argtype!=subnode.dtype: raise TypeError("Function %s takes arguments with types %s but " "received %s" % (node.func.id, funcvar._arg_types, node_arg_types)) # compute return type return_type = funcvar._return_type if return_type=='highest': return_type = dtype_hierarchy[max(dtype_hierarchy[nat] for nat in node_arg_types)] node.dtype = return_type else: node.stateless = False # we leave node.func because it is an ast.Name object that doesn't have a dtype # TODO: variable complexity for function calls? node.complexity = 20+sum(subnode.complexity for subnode in node.args) return node
[docs] def render_BinOp(self, node): node.left = self.render_node(node.left) node.right = self.render_node(node.right) # TODO: we could capture some syntax errors here, e.g. bool+bool # captures, e.g. int+float->float newdtype = dtype_hierarchy[max(dtype_hierarchy[subnode.dtype] for subnode in [node.left, node.right])] node.dtype = newdtype node.scalar = node.left.scalar and node.right.scalar node.complexity = 1+node.left.complexity+node.right.complexity node.stateless = node.left.stateless and node.right.stateless return node
[docs] def render_BoolOp(self, node): node.values = [self.render_node(subnode) for subnode in node.values] node.dtype = 'boolean' for subnode in node.values: if subnode.dtype!='boolean': raise TypeError("Boolean operator acting on non-booleans") node.scalar = logical_all(subnode.scalar for subnode in node.values) node.complexity = 1+sum(subnode.complexity for subnode in node.values) node.stateless = logical_all(subnode.stateless for subnode in node.values) return node
[docs] def render_Compare(self, node): node.left = self.render_node(node.left) node.comparators = [self.render_node(subnode) for subnode in node.comparators] node.dtype = 'boolean' comparators = [node.left]+node.comparators node.scalar = logical_all(subnode.scalar for subnode in comparators) node.complexity = 1+sum(subnode.complexity for subnode in comparators) node.stateless = node.left.stateless and all(c.stateless for c in node.comparators) return node
[docs] def render_UnaryOp(self, node): node.operand = self.render_node(node.operand) node.dtype = node.operand.dtype if node.dtype=='boolean' and node.op.__class__.__name__ != 'Not': raise TypeError("Unary operator %s does not apply to boolean types" % node.op.__class__.__name__) node.scalar = node.operand.scalar node.complexity = 1+node.operand.complexity node.stateless = node.operand.stateless return node
if __name__=='__main__':
    # Small manual demonstration: annotate one expression using the variables
    # of a NeuronGroup and print the attributes of the top-level node
    import brian2
    eqs = '''
    x : 1
    y : 1 (shared)
    a : integer
    b : boolean
    c : integer (shared)
    '''
    expr = 'x<3.0+1.0'
    G = brian2.NeuronGroup(2, eqs)
    variables = {}
    variables.update(**brian2.DEFAULT_FUNCTIONS)
    variables.update(**brian2.DEFAULT_CONSTANTS)
    variables.update(**G.variables)
    node = brian_ast(expr, variables)
    # A single formatted string prints the same space-separated text as the
    # Python 2 print statement, and is also valid Python 3
    print('%s %s %s' % (node.dtype, node.scalar, node.complexity))