Source code for brian2.parsing.rendering

import ast
import numbers

import numpy as np
import sympy

from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS

__all__ = [
    "NodeRenderer",
    "NumpyNodeRenderer",
    "CPPNodeRenderer",
    "SympyNodeRenderer",
]


[docs] class NodeRenderer: expression_ops = { # BinOp "Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "FloorDiv": "//", "Pow": "**", "Mod": "%", # Compare "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=", "Eq": "==", "NotEq": "!=", # Unary ops "Not": "not", "UAdd": "+", "USub": "-", # Bool ops "And": "and", "Or": "or", # Augmented assign "AugAdd": "+=", "AugSub": "-=", "AugMult": "*=", "AugDiv": "/=", "AugPow": "**=", "AugMod": "%=", } def __init__(self, auto_vectorise=None): if auto_vectorise is None: auto_vectorise = set() self.auto_vectorise = auto_vectorise
[docs] def render_expr(self, expr, strip=True): if strip: expr = expr.strip() node = ast.parse(expr, mode="eval") return self.render_node(node.body)
[docs] def render_node(self, node): nodename = node.__class__.__name__ methname = f"render_{nodename}" try: return getattr(self, methname)(node) except AttributeError: if nodename == "Subscript": raise SyntaxError( "Brian equations/expressions do not support indexing with '[...]'." ) elif nodename == "Attribute": raise SyntaxError( "Brian equations/expressions do not support accessing attributes" " with the '.' syntax." ) elif nodename == "Tuple": raise SyntaxError("Brian equations/expressions do not support tuples.") else: raise SyntaxError( f"Brian equations/expressions do not support the '{nodename}'" " syntax." )
[docs] def render_func(self, node): return self.render_Name(node)
[docs] def render_Name(self, node): return node.id
[docs] def render_Constant(self, node): if isinstance(node.value, np.number): # repr prints the dtype in numpy 2.0 return repr(node.value.item()) return repr(node.value)
[docs] def render_Call(self, node): if len(node.keywords): raise ValueError("Keyword arguments not supported.") elif getattr(node, "starargs", None) is not None: raise ValueError("Variable number of arguments not supported") elif getattr(node, "kwargs", None) is not None: raise ValueError("Keyword arguments not supported") else: if node.func.id in self.auto_vectorise: vectorisation_idx = ast.Name() vectorisation_idx.id = "_vectorisation_idx" args = node.args + [vectorisation_idx] else: args = node.args return f"{self.render_func(node.func)}({', '.join(self.render_node(arg) for arg in args)})"
[docs] def render_element_parentheses(self, node): """ Render an element with parentheses around it or leave them away for numbers, names and function calls. """ if node.__class__.__name__ == "Name": return self.render_node(node) elif node.__class__.__name__ in ["Num", "Constant"] and node.value >= 0: return self.render_node(node) elif node.__class__.__name__ == "Call": return self.render_node(node) else: return f"({self.render_node(node)})"
[docs] def render_BinOp_parentheses(self, left, right, op): # Use a simplified checking whether it is possible to omit parentheses: # only omit parentheses for numbers, variable names or function calls. # This means we still put needless parentheses because we ignore # precedence rules, e.g. we write "3 + (4 * 5)" but at least we do # not do "(3) + ((4) + (5))" op_class = op.__class__.__name__ # Give a more useful error message when using bit-wise operators if op_class in ["BitXor", "BitAnd", "BitOr"]: correction = { "BitXor": ("^", "**"), "BitAnd": ("&", "and"), "BitOr": ("|", "or"), }.get(op_class) raise SyntaxError( f'The operator "{correction[0]}" is not supported, use' f' "{correction[1]}" instead.' ) return ( f"{self.render_element_parentheses(left)} " f"{self.expression_ops[op_class]} " f"{self.render_element_parentheses(right)}" )
[docs] def render_BinOp(self, node): return self.render_BinOp_parentheses(node.left, node.right, node.op)
[docs] def render_BoolOp(self, node): op = self.expression_ops[node.op.__class__.__name__] return (f" {op} ").join( f"{self.render_element_parentheses(v)}" for v in node.values )
[docs] def render_Compare(self, node): if len(node.comparators) > 1: raise SyntaxError("Can only handle single comparisons like a<b not a<b<c") return self.render_BinOp_parentheses( node.left, node.comparators[0], node.ops[0] )
[docs] def render_UnaryOp(self, node): return f"{self.expression_ops[node.op.__class__.__name__]} {self.render_element_parentheses(node.operand)}"
[docs] def render_Assign(self, node): if len(node.targets) > 1: raise SyntaxError("Only support syntax like a=b not a=b=c") return f"{self.render_node(node.targets[0])} = {self.render_node(node.value)}"
[docs] def render_AugAssign(self, node): target = node.target.id rhs = self.render_node(node.value) op = self.expression_ops[f"Aug{node.op.__class__.__name__}"] return f"{target} {op} {rhs}"
[docs] class NumpyNodeRenderer(NodeRenderer): expression_ops = NodeRenderer.expression_ops.copy() expression_ops.update( { # Unary ops # We'll handle "not" explicitly below # Bool ops "And": "&", "Or": "|", } )
[docs] def render_UnaryOp(self, node): if node.op.__class__.__name__ == "Not": return f"logical_not({self.render_node(node.operand)})" else: return NodeRenderer.render_UnaryOp(self, node)
[docs] class SympyNodeRenderer(NodeRenderer): expression_ops = { "Add": sympy.Add, "Mult": sympy.Mul, "Pow": sympy.Pow, "Mod": sympy.Mod, # Compare "Lt": sympy.StrictLessThan, "LtE": sympy.LessThan, "Gt": sympy.StrictGreaterThan, "GtE": sympy.GreaterThan, "Eq": sympy.Eq, "NotEq": sympy.Ne, # Unary ops are handled manually # Bool ops "And": sympy.And, "Or": sympy.Or, }
[docs] def render_func(self, node): if node.id in DEFAULT_FUNCTIONS: f = DEFAULT_FUNCTIONS[node.id] if f.sympy_func is not None and isinstance( f.sympy_func, sympy.FunctionClass ): return f.sympy_func # special workaround for the "int" function if node.id == "int": return sympy.Function("int_") else: return sympy.Function(node.id)
[docs] def render_Call(self, node): if len(node.keywords): raise ValueError("Keyword arguments not supported.") elif getattr(node, "starargs", None) is not None: raise ValueError("Variable number of arguments not supported") elif getattr(node, "kwargs", None) is not None: raise ValueError("Keyword arguments not supported") elif len(node.args) == 0: return self.render_func(node.func)(sympy.Symbol("_placeholder_arg")) else: return self.render_func(node.func)( *(self.render_node(arg) for arg in node.args) )
[docs] def render_Compare(self, node): if len(node.comparators) > 1: raise SyntaxError("Can only handle single comparisons like a<b not a<b<c") op = node.ops[0] return self.expression_ops[op.__class__.__name__]( self.render_node(node.left), self.render_node(node.comparators[0]) )
[docs] def render_Name(self, node): if node.id in DEFAULT_CONSTANTS: c = DEFAULT_CONSTANTS[node.id] return c.sympy_obj elif node.id in ["t", "dt"]: return sympy.Symbol(node.id, real=True, positive=True) else: return sympy.Symbol(node.id, real=True)
[docs] def render_Constant(self, node): if node.value is True or node.value is False: return node.value elif isinstance(node.value, numbers.Integral): return sympy.Integer(node.value) elif isinstance(node.value, numbers.Number): return sympy.Float(node.value) else: return str(node.value)
[docs] def render_BinOp(self, node): op_name = node.op.__class__.__name__ # Sympy implements division and subtraction as multiplication/addition if op_name == "Div": op = self.expression_ops["Mult"] return op(self.render_node(node.left), 1 / self.render_node(node.right)) elif op_name == "FloorDiv": op = self.expression_ops["Mult"] left = self.render_node(node.left) right = self.render_node(node.right) return sympy.floor(op(left, 1 / right)) elif op_name == "Sub": op = self.expression_ops["Add"] return op(self.render_node(node.left), -self.render_node(node.right)) else: op = self.expression_ops[op_name] return op(self.render_node(node.left), self.render_node(node.right))
[docs] def render_BoolOp(self, node): op = self.expression_ops[node.op.__class__.__name__] return op(*(self.render_node(value) for value in node.values))
[docs] def render_UnaryOp(self, node): op_name = node.op.__class__.__name__ if op_name == "UAdd": # Nothing to do return self.render_node(node.operand) elif op_name == "USub": return -self.render_node(node.operand) elif op_name == "Not": return sympy.Not(self.render_node(node.operand)) else: raise ValueError(f"Unknown unary operator: {op_name}")
[docs] class CPPNodeRenderer(NodeRenderer): expression_ops = NodeRenderer.expression_ops.copy() expression_ops.update( { # Unary ops "Not": "!", # Bool ops "And": "&&", "Or": "||", # C does not have a floor division operator (but see render_BinOp) "FloorDiv": "/", } )
[docs] def render_BinOp(self, node): if node.op.__class__.__name__ == "Pow": return ( f"_brian_pow({self.render_node(node.left)}," f" {self.render_node(node.right)})" ) elif node.op.__class__.__name__ == "Mod": return ( f"_brian_mod({self.render_node(node.left)}," f" {self.render_node(node.right)})" ) elif node.op.__class__.__name__ == "Div": # C uses integer division, this is a quick and dirty way to assure # it uses floating point division for integers return f"1.0f*{self.render_element_parentheses(node.left)}/{self.render_element_parentheses(node.right)}" elif node.op.__class__.__name__ == "FloorDiv": return ( f"_brian_floordiv({self.render_node(node.left)}," f" {self.render_node(node.right)})" ) else: return NodeRenderer.render_BinOp(self, node)
[docs] def render_Constant(self, node): if node.value is True: return "true" elif node.value is False: return "false" else: return super().render_Constant(node)
[docs] def render_Name(self, node): if node.id == "inf": return "INFINITY" else: return node.id
[docs] def render_Assign(self, node): return f"{NodeRenderer.render_Assign(self, node)};"