import ast
import sympy
from brian2.core.functions import DEFAULT_FUNCTIONS, DEFAULT_CONSTANTS
# Names exported by ``from <this module> import *``: the generic renderer
# and the three subclasses defined further down in this file.
__all__ = ['NodeRenderer',
'NumpyNodeRenderer',
'CPPNodeRenderer',
'SympyNodeRenderer'
]
class NodeRenderer(object):
    '''
    Render Python source (parsed with `ast`) back into a string.

    Each AST node of class ``X`` is handled by a method called ``render_X``;
    `render_node` performs the dispatch. Subclasses customise the output by
    overriding `expression_ops` and/or individual ``render_*`` methods.

    Parameters
    ----------
    use_vectorisation_idx : bool, optional
        If ``True`` (the default), a call without arguments such as
        ``randn()`` is rendered as ``randn(_vectorisation_idx)``.
    '''
    # Maps the class name of an AST operator node to its textual form.
    # Augmented assignments are looked up as 'Aug' + operator class name.
    expression_ops = {
        # BinOp
        'Add': '+',
        'Sub': '-',
        'Mult': '*',
        'Div': '/',
        'FloorDiv': '//',
        'Pow': '**',
        'Mod': '%',
        # Compare
        'Lt': '<',
        'LtE': '<=',
        'Gt': '>',
        'GtE': '>=',
        'Eq': '==',
        'NotEq': '!=',
        # Unary ops
        'Not': 'not',
        'UAdd': '+',
        'USub': '-',
        # Bool ops
        'And': 'and',
        'Or': 'or',
        # Augmented assign
        'AugAdd': '+=',
        'AugSub': '-=',
        'AugMult': '*=',
        'AugDiv': '/=',
        'AugPow': '**=',
        'AugMod': '%=',
    }

    def __init__(self, use_vectorisation_idx=True):
        self.use_vectorisation_idx = use_vectorisation_idx

    @staticmethod
    def _constant_value(node):
        '''
        Return the Python value stored in a literal node.

        ``ast.Constant`` (Python >= 3.8) and ``ast.NameConstant`` store it in
        ``value``; the legacy ``ast.Num`` node only provides ``n``.
        '''
        try:
            return node.value
        except AttributeError:
            return node.n

    def render_expr(self, expr, strip=True):
        '''
        Render a single expression given as a string.

        Parameters
        ----------
        expr : str
            The expression to parse and render.
        strip : bool, optional
            Whether to strip surrounding whitespace before parsing
            (defaults to ``True``).
        '''
        if strip:
            expr = expr.strip()
        node = ast.parse(expr, mode='eval')
        return self.render_node(node.body)

    def render_code(self, code):
        '''
        Render a block of statements, one output line per top-level statement.
        '''
        lines = [self.render_node(node) for node in ast.parse(code).body]
        return '\n'.join(lines)

    def render_node(self, node):
        '''
        Dispatch to the ``render_<NodeClass>`` method for `node`.

        Raises
        ------
        SyntaxError
            If no such method exists. Note that an `AttributeError` raised
            *inside* a render method is reported in the same way.
        '''
        nodename = node.__class__.__name__
        methname = 'render_'+nodename
        try:
            return getattr(self, methname)(node)
        except AttributeError:
            raise SyntaxError("Unknown syntax: " + nodename)

    def render_func(self, node):
        '''Render the callee of a function call (rendered like a name).'''
        return self.render_Name(node)

    def render_NameConstant(self, node):
        '''Render ``True``, ``False`` or ``None``.'''
        return str(node.value)

    def render_Name(self, node):
        '''Render an identifier.'''
        return node.id

    def render_Num(self, node):
        '''Render a numeric literal using its `repr`.'''
        return repr(self._constant_value(node))

    def render_Constant(self, node):
        '''
        Render a literal as parsed by Python >= 3.8.

        Newer Python versions emit ``ast.Constant`` for all literals instead
        of ``Num``/``NameConstant``. The node is forwarded to
        `render_NameConstant` or `render_Num` so that overrides of these
        methods in subclasses keep working.

        Raises
        ------
        SyntaxError
            For literals that are neither numbers nor ``True``/``False``/
            ``None`` (e.g. strings).
        '''
        value = node.value
        # bool has to be checked before the numeric types (it subclasses int)
        if value is None or isinstance(value, bool):
            return self.render_NameConstant(node)
        if isinstance(value, (int, float, complex)):
            return self.render_Num(node)
        raise SyntaxError("Unknown syntax: %s literal" % type(value).__name__)

    def render_Call(self, node):
        '''
        Render a function call with positional arguments only.

        Raises
        ------
        ValueError
            If the call uses keyword arguments or ``*args``/``**kwargs``
            (the latter two checks apply to the AST of older Python versions).
        '''
        if node.keywords:
            raise ValueError("Keyword arguments not supported.")
        elif getattr(node, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        elif getattr(node, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")
        if len(node.args) == 0 and self.use_vectorisation_idx:
            # argument-less function call such as randn() are transformed into
            # randn(_vectorisation_idx) -- this is important for Python code
            # in particular, because it has to return an array of values.
            return '%s(%s)' % (self.render_func(node.func),
                               '_vectorisation_idx')
        else:
            return '%s(%s)' % (self.render_func(node.func),
                               ', '.join(self.render_node(arg)
                                         for arg in node.args))

    def render_element_parentheses(self, node):
        '''
        Render an element with parentheses around it or leave them away for
        numbers, names and function calls.
        '''
        nodename = node.__class__.__name__
        if nodename in ('Name', 'NameConstant', 'Call'):
            return self.render_node(node)
        if nodename in ('Num', 'Constant'):
            value = self._constant_value(node)
            if value is None or isinstance(value, bool):
                return self.render_node(node)
            try:
                non_negative = value >= 0
            except TypeError:
                # Not orderable (e.g. complex): play it safe and parenthesise
                non_negative = False
            if non_negative:
                return self.render_node(node)
        return '(%s)' % self.render_node(node)

    def render_BinOp_parentheses(self, left, right, op):
        '''
        Render ``left <op> right``, parenthesising non-trivial operands.

        Raises
        ------
        SyntaxError
            For the bit-wise operators ``^``, ``&`` and ``|``.
        '''
        # Use a simplified checking whether it is possible to omit parentheses:
        # only omit parentheses for numbers, variable names or function calls.
        # This means we still put needless parentheses because we ignore
        # precedence rules, e.g. we write "3 + (4 * 5)" but at least we do
        # not do "(3) + ((4) + (5))"
        op_class = op.__class__.__name__
        # Give a more useful error message when using bit-wise operators
        if op_class in ['BitXor', 'BitAnd', 'BitOr']:
            correction = {'BitXor': ('^', '**'),
                          'BitAnd': ('&', 'and'),
                          'BitOr': ('|', 'or')}.get(op_class)
            raise SyntaxError('The operator "{}" is not supported, use "{}" '
                              'instead.'.format(correction[0], correction[1]))
        return '%s %s %s' % (self.render_element_parentheses(left),
                             self.expression_ops[op_class],
                             self.render_element_parentheses(right))

    def render_BinOp(self, node):
        '''Render a binary arithmetic operation.'''
        return self.render_BinOp_parentheses(node.left, node.right, node.op)

    def render_BoolOp(self, node):
        '''Render a chain of operands joined by ``and``/``or``.'''
        op = self.expression_ops[node.op.__class__.__name__]
        return (' '+op+' ').join('%s' % self.render_element_parentheses(v)
                                 for v in node.values)

    def render_Compare(self, node):
        '''
        Render a single comparison.

        Raises
        ------
        SyntaxError
            For chained comparisons such as ``a<b<c``.
        '''
        if len(node.comparators)>1:
            raise SyntaxError("Can only handle single comparisons like a<b not a<b<c")
        return self.render_BinOp_parentheses(node.left, node.comparators[0], node.ops[0])

    def render_UnaryOp(self, node):
        '''Render a unary operation as ``<op> <operand>``.'''
        return '%s %s' % (self.expression_ops[node.op.__class__.__name__],
                          self.render_element_parentheses(node.operand))

    def render_Assign(self, node):
        '''
        Render an assignment to a single target.

        Raises
        ------
        SyntaxError
            For chained assignments such as ``a=b=c``.
        '''
        if len(node.targets)>1:
            raise SyntaxError("Only support syntax like a=b not a=b=c")
        return '%s = %s' % (self.render_node(node.targets[0]),
                            self.render_node(node.value))

    def render_AugAssign(self, node):
        '''Render an augmented assignment such as ``a += b``.'''
        target = node.target.id
        rhs = self.render_node(node.value)
        op = self.expression_ops['Aug'+node.op.__class__.__name__]
        return '%s %s %s' % (target, op, rhs)
class NumpyNodeRenderer(NodeRenderer):
    '''
    Renderer that writes boolean logic as ``&``, ``|`` and
    ``logical_not(...)`` instead of Python's ``and``, ``or`` and ``not``.
    '''
    # Start from the parent's operator table and swap in the boolean
    # operators; "not" is rendered by `render_UnaryOp` below.
    expression_ops = dict(NodeRenderer.expression_ops)
    expression_ops['And'] = '&'
    expression_ops['Or'] = '|'

    def render_UnaryOp(self, node):
        '''
        Render ``not x`` as ``logical_not(x)``; every other unary operator is
        rendered by the parent class.
        '''
        if type(node.op).__name__ != 'Not':
            return NodeRenderer.render_UnaryOp(self, node)
        operand = self.render_node(node.operand)
        return 'logical_not(%s)' % operand
class SympyNodeRenderer(NodeRenderer):
    '''
    Renderer that builds sympy objects instead of strings.

    The ``render_*`` methods return sympy expressions (or plain Python
    booleans for ``True``/``False``), and `expression_ops` maps operator
    names to sympy classes rather than to operator symbols.
    '''
    expression_ops = {
        'Add': sympy.Add,
        'Mult': sympy.Mul,
        'Pow': sympy.Pow,
        'Mod': sympy.Mod,
        # Compare
        'Lt': sympy.StrictLessThan,
        'LtE': sympy.LessThan,
        'Gt': sympy.StrictGreaterThan,
        'GtE': sympy.GreaterThan,
        'Eq': sympy.Eq,
        'NotEq': sympy.Ne,
        # Unary ops are handled manually
        # Bool ops
        'And': sympy.And,
        'Or': sympy.Or,
    }

    def render_func(self, node):
        '''
        Return the sympy function object for the function named by `node`.

        A function listed in ``DEFAULT_FUNCTIONS`` whose ``sympy_func`` is a
        `sympy.FunctionClass` is returned directly; anything else becomes an
        undefined `sympy.Function` of the same name (``int`` is renamed to
        ``int_``).
        '''
        name = node.id
        if name in DEFAULT_FUNCTIONS:
            candidate = DEFAULT_FUNCTIONS[name].sympy_func
            if candidate is not None and isinstance(candidate,
                                                    sympy.FunctionClass):
                return candidate
        # special workaround for the "int" function
        return sympy.Function("int_" if name == 'int' else name)

    def render_Call(self, node):
        '''
        Apply the rendered function to its rendered arguments.

        A call without arguments receives the symbol ``_vectorisation_idx``
        as its only argument. Raises `ValueError` for keyword arguments and
        for ``*args``/``**kwargs``.
        '''
        if len(node.keywords):
            raise ValueError("Keyword arguments not supported.")
        if getattr(node, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        if getattr(node, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")
        if len(node.args) == 0:
            func = self.render_func(node.func)
            return func(sympy.Symbol('_vectorisation_idx'))
        func = self.render_func(node.func)
        rendered_args = [self.render_node(arg) for arg in node.args]
        return func(*rendered_args)

    def render_Compare(self, node):
        '''
        Build the sympy relational object for a single comparison; chained
        comparisons raise `SyntaxError`.
        '''
        if len(node.comparators)>1:
            raise SyntaxError("Can only handle single comparisons like a<b not a<b<c")
        relation = self.expression_ops[type(node.ops[0]).__name__]
        lhs = self.render_node(node.left)
        rhs = self.render_node(node.comparators[0])
        return relation(lhs, rhs)

    def render_Name(self, node):
        '''
        Return the sympy object for an identifier: the ``sympy_obj`` of a
        known constant, otherwise a real symbol (also positive for ``t`` and
        ``dt``).
        '''
        name = node.id
        if name in DEFAULT_CONSTANTS:
            return DEFAULT_CONSTANTS[name].sympy_obj
        if name in ['t', 'dt']:
            return sympy.Symbol(name, real=True, positive=True)
        return sympy.Symbol(name, real=True)

    def render_NameConstant(self, node):
        '''Return booleans unchanged, anything else as its string form.'''
        value = node.value
        return value if value in [True, False] else str(value)

    def render_Num(self, node):
        '''Return a numeric literal as a `sympy.Float`.'''
        return sympy.Float(node.n)

    def render_BinOp(self, node):
        '''
        Build the sympy expression for a binary operation.

        Sympy has no dedicated subtraction/division classes, so ``a - b`` is
        built as ``Add(a, -b)``, ``a / b`` as ``Mul(a, 1/b)`` and ``a // b``
        as ``floor(Mul(a, 1/b))``.
        '''
        op_name = type(node.op).__name__
        if op_name == 'Sub':
            plus = self.expression_ops['Add']
            minuend = self.render_node(node.left)
            return plus(minuend, -self.render_node(node.right))
        if op_name in ('Div', 'FloorDiv'):
            times = self.expression_ops['Mult']
            numerator = self.render_node(node.left)
            denominator = self.render_node(node.right)
            quotient = times(numerator, 1 / denominator)
            if op_name == 'FloorDiv':
                return sympy.floor(quotient)
            return quotient
        combine = self.expression_ops[op_name]
        lhs = self.render_node(node.left)
        rhs = self.render_node(node.right)
        return combine(lhs, rhs)

    def render_BoolOp(self, node):
        '''Combine all operands with `sympy.And` or `sympy.Or`.'''
        connective = self.expression_ops[type(node.op).__name__]
        operands = [self.render_node(value) for value in node.values]
        return connective(*operands)

    def render_UnaryOp(self, node):
        '''
        Apply unary ``+`` (no-op), unary ``-`` (negation) or ``not``
        (`sympy.Not`); any other operator raises `ValueError`.
        '''
        op_name = type(node.op).__name__
        if op_name not in ('UAdd', 'USub', 'Not'):
            raise ValueError('Unknown unary operator: ' + op_name)
        operand = self.render_node(node.operand)
        if op_name == 'UAdd':
            # Nothing to do
            return operand
        if op_name == 'USub':
            return -operand
        return sympy.Not(operand)
class CPPNodeRenderer(NodeRenderer):
    '''
    Renderer emitting C++-flavoured code: ``!``, ``&&`` and ``||`` for boolean
    logic, ``_brian_*`` helper calls for power, modulo and floor division, and
    a trailing semicolon on assignments.
    '''
    expression_ops = dict(NodeRenderer.expression_ops)
    # Unary ops
    expression_ops['Not'] = '!'
    # Bool ops
    expression_ops['And'] = '&&'
    expression_ops['Or'] = '||'
    # C does not have a floor division operator (but see render_BinOp)
    expression_ops['FloorDiv'] = '/'

    def render_BinOp(self, node):
        '''
        Render a binary operation, replacing ``**``, ``%`` and ``//`` by
        function calls and forcing floating point division for ``/``.
        '''
        op_name = type(node.op).__name__
        call_templates = {'Pow': '_brian_pow(%s, %s)',
                          'Mod': '_brian_mod(%s, %s)',
                          'FloorDiv': '_brian_floordiv(%s, %s)'}
        if op_name in call_templates:
            lhs = self.render_node(node.left)
            rhs = self.render_node(node.right)
            return call_templates[op_name] % (lhs, rhs)
        if op_name == 'Div':
            # C uses integer division, this is a quick and dirty way to assure
            # it uses floating point division for integers
            numerator = self.render_element_parentheses(node.left)
            denominator = self.render_element_parentheses(node.right)
            return '1.0f*%s/%s' % (numerator, denominator)
        return NodeRenderer.render_BinOp(self, node)

    def render_NameConstant(self, node):
        '''
        Map ``True``/``False`` to ``true``/``false``; any other value is
        returned unchanged.
        '''
        # In Python 3.4, None, True and False go here
        value = node.value
        cpp_literals = {True: 'true', False: 'false'}
        return cpp_literals.get(value, value)

    def render_Name(self, node):
        '''
        Render an identifier, translating ``True``, ``False`` and ``inf`` to
        ``true``, ``false`` and ``INFINITY``.
        '''
        # Replace Python's True and False with their C++ bool equivalents
        replacements = {'True': 'true',
                        'False': 'false',
                        'inf': 'INFINITY'}
        identifier = node.id
        return replacements.get(identifier, identifier)

    def render_Assign(self, node):
        '''Render an assignment as in the parent class and append ``;``.'''
        statement = NodeRenderer.render_Assign(self, node)
        return statement+';'