import itertools
import numpy as np
from brian2.core.functions import DEFAULT_FUNCTIONS, timestep
from brian2.core.variables import ArrayVariable
from brian2.parsing.bast import brian_dtype_from_dtype
from brian2.parsing.rendering import NumpyNodeRenderer
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers, indent, word_substitute
from .base import CodeGenerator
# Public API of this module: only the code generator class is exported.
__all__ = ["NumpyCodeGenerator"]
# Module-level logger (e.g. used to report falling back to a Python loop).
logger = get_logger(__name__)
class VectorisationError(Exception):
    """
    Signal that a statement cannot be translated into vectorised numpy code.

    Within this module, ``NumpyCodeGenerator.vectorise_code`` catches this
    exception and falls back to generating an explicit Python loop.
    """
class NumpyCodeGenerator(CodeGenerator):
    """
    Numpy language

    Essentially Python but vectorised.
    """

    class_name = "numpy"

    _use_ufunc_at_vectorisation = True  # allow this to be off for testing only

    def translate_expression(self, expr):
        """
        Translate an abstract-code expression into numpy code.

        Function names are first replaced according to
        ``self.func_name_replacements``; the result is then rendered by a
        `NumpyNodeRenderer` configured with ``self.auto_vectorise``.

        Parameters
        ----------
        expr : str
            The expression to translate.

        Returns
        -------
        str
            The rendered expression, stripped of surrounding whitespace.
        """
        expr = word_substitute(expr, self.func_name_replacements)
        return (
            NumpyNodeRenderer(auto_vectorise=self.auto_vectorise)
            .render_expr(expr, self.variables)
            .strip()
        )

    def translate_statement(self, statement):
        """
        Translate a single statement into one line of numpy code.

        ``:=`` (variable creation) is rendered as a plain assignment. If the
        statement uses exactly one boolean variable, has a float dtype, and the
        boolean-simplified expressions are cheaper than the standard one, the
        statement is rendered as
        ``var op _numpy.where(boolvar, expr_if_true, expr_if_false)``.

        Parameters
        ----------
        statement : Statement
            The statement to translate.

        Returns
        -------
        str
            The line of code, with the statement's comment appended if present.
        """
        # TODO: optimisation, translate arithmetic to a sequence of inplace
        # operations like a=b+c -> add(b, c, a)
        var, op, expr, comment = (
            statement.var,
            statement.op,
            statement.expr,
            statement.comment,
        )
        if op == ":=":
            op = "="
        # For numpy we replace complex expressions involving a single boolean
        # variable into a where(boolvar, expr_if_true, expr_if_false)
        if (
            statement.used_boolean_variables is not None
            and len(statement.used_boolean_variables) == 1
            and brian_dtype_from_dtype(statement.dtype) == "float"
            and statement.complexity_std > sum(statement.complexities.values())
        ):
            boolvar = statement.used_boolean_variables[0]
            bool_simp = statement.boolean_simplified_expressions
            # Each key is a tuple of (boolvar, value) pairs; with a single
            # boolean variable it holds exactly one pair.
            branches = {}
            for bool_assigns, simp_expr in bool_simp.items():
                _, boolval = bool_assigns[0]
                # The simplified expressions are abstract code just like
                # ``statement.expr`` and need the same translation (function
                # name replacements, numpy rendering).
                branches[bool(boolval)] = self.translate_expression(simp_expr)
            code = (
                f"{var} {op} _numpy.where({boolvar}, "
                f"{branches[True]}, {branches[False]})"
            )
        else:
            code = f"{var} {op} {self.translate_expression(expr)}"
        if len(comment):
            code += f" # {comment}"
        return code

    def ufunc_at_vectorisation(
        self,
        statement,
        variables,
        indices,
        conditional_write_vars,
        created_vars,
        used_variables,
    ):
        """
        Translate one statement for code that may use repeated indices.

        In-place operations (``+=``, ``-=``, ``*=``, ``/=``) on a variable whose
        index is not ``_idx`` become a ``ufunc.at`` call on the underlying
        array; all other statements become plain assignments. Conditional
        writes are applied to either kind of line.

        Parameters
        ----------
        statement : Statement
            The statement to translate.
        variables : dict
            The variables available to the code.
        indices : dict
            Mapping from variable names to the names of their index variables.
        conditional_write_vars : dict
            Mapping from variable names to the condition guarding their writes.
        created_vars : set
            Names of variables created with ``:=`` in this code block.
        used_variables : set
            Variables read through a non-``_idx`` index by this or earlier
            statements; updated in place.

        Returns
        -------
        str
            The line of code, with the statement's comment appended if present.

        Raises
        ------
        VectorisationError
            If ufunc vectorisation is switched off, if the written variable has
            been read through a non-``_idx`` index (by this or an earlier
            statement), or if the in-place operator is not supported.
        """
        if not self._use_ufunc_at_vectorisation:
            raise VectorisationError()
        # Avoids circular import
        from brian2.devices.device import device

        # See https://github.com/brian-team/brian2/pull/531 for explanation:
        # remember every variable read through a non-_idx index and refuse to
        # vectorise a statement writing to one of them.
        used = set(get_identifiers(statement.expr))
        used = used.intersection(
            k for k in list(variables.keys()) if k in indices and indices[k] != "_idx"
        )
        used_variables.update(used)
        if statement.var in used_variables:
            raise VectorisationError()

        # Translate exactly as in the non-repeated-index path, so that function
        # name replacements also apply here.
        expr = self.translate_expression(statement.expr)

        if (
            statement.op == ":="
            or indices[statement.var] == "_idx"
            or not statement.inplace
        ):
            op = "=" if statement.op == ":=" else statement.op
            line = f"{statement.var} {op} {expr}"
        else:
            # In-place update through a (potentially repeated) index: the
            # ufunc's ``at`` method applies the operation for every occurrence
            # of an index, not only once.
            ufunc_names = {
                "+=": "_numpy.add",
                "*=": "_numpy.multiply",
                "/=": "_numpy.divide",
                "-=": "_numpy.subtract",
            }
            ufunc_name = ufunc_names.get(statement.op)
            if ufunc_name is None:
                raise VectorisationError()
            array_name = device.get_array_name(variables[statement.var])
            idx = indices[statement.var]
            line = f"{ufunc_name}.at({array_name}, {idx}, {expr})"
        # Honour conditional writes for both kinds of line, as the
        # non-repeated-index path does for every statement.
        line = self.conditional_write(
            line,
            statement,
            variables,
            conditional_write_vars=conditional_write_vars,
            created_vars=created_vars,
        )
        if len(statement.comment):
            line += f" # {statement.comment}"
        return line

    def vectorise_code(self, statements, variables, variable_indices, index="_idx"):
        """
        Translate statements that may write through repeated indices.

        Each statement is translated on its own, with its own read and write
        code, via `ufunc_at_vectorisation`. If any statement raises
        `VectorisationError`, the whole sequence is instead translated into an
        explicit Python loop over ``_idx``.

        Parameters
        ----------
        statements : list of Statement
            The statements to translate.
        variables : dict
            The variables available to the code.
        variable_indices : dict
            Mapping from variable names to the names of their index variables.
        index : str, optional
            Not used by this method; kept for interface compatibility.

        Returns
        -------
        list of str
            The lines of generated code.
        """
        created_vars = {stmt.var for stmt in statements if stmt.op == ":="}
        try:
            lines = []
            used_variables = set()
            for statement in statements:
                lines.append(
                    "# Abstract code: "
                    f" {statement.var} {statement.op} {statement.expr}"
                )
                # We treat every statement individually with its own read and
                # write code to be on the safe side
                read, write, indices, conditional_write_vars = self.arrays_helper(
                    [statement]
                )
                # We make sure that we only add code to `lines` after it went
                # through completely
                ufunc_lines = []
                # No need to load a variable if it is only in read because of
                # the in-place operation
                if (
                    statement.inplace
                    and variable_indices[statement.var] != "_idx"
                    and statement.var not in get_identifiers(statement.expr)
                ):
                    read = read - {statement.var}
                ufunc_lines.extend(
                    self.read_arrays(read, write, indices, variables, variable_indices)
                )
                ufunc_lines.append(
                    self.ufunc_at_vectorisation(
                        statement,
                        variables,
                        variable_indices,
                        conditional_write_vars,
                        created_vars,
                        used_variables,
                    )
                )
                # Do not write back such values, the ufuncs have modified the
                # underlying array already
                if statement.inplace and variable_indices[statement.var] != "_idx":
                    write = write - {statement.var}
                ufunc_lines.extend(
                    self.write_arrays(
                        [statement], read, write, variables, variable_indices
                    )
                )
                lines.extend(ufunc_lines)
        except VectorisationError:
            if self._use_ufunc_at_vectorisation:
                logger.info(
                    "Failed to vectorise code, falling back on Python loop: note that"
                    " this will be very slow! Switch to another code generation target"
                    " for best performance (e.g. cython). First line is: "
                    + str(statements[0]),
                    once=True,
                )
            # Loop over the indices one at a time. All loop-body lines go
            # through ``indent`` so that their indentation is consistent.
            lines = [
                "_full_idx = _idx",
                "for _idx in _full_idx:",
                indent("_vectorisation_idx = _idx"),
            ]
            read, write, indices, conditional_write_vars = self.arrays_helper(
                statements
            )
            lines.extend(
                indent(code)
                for code in self.read_arrays(
                    read, write, indices, variables, variable_indices
                )
            )
            for statement in statements:
                line = self.translate_statement(statement)
                if statement.var in conditional_write_vars:
                    lines.append(indent(f"if {conditional_write_vars[statement.var]}:"))
                    lines.append(indent(line, 2))
                else:
                    lines.append(indent(line))
            lines.extend(
                indent(code)
                for code in self.write_arrays(
                    statements, read, write, variables, variable_indices
                )
            )
        return lines

    def read_arrays(self, read, write, indices, variables, variable_indices):
        """
        Generate the lines that load variables from their arrays.

        Index variables are loaded first, then the other read variables. A
        variable whose index is not in ``self.iterate_all`` is loaded through
        that index. A variable read in full that is also written is copied.

        Returns
        -------
        list of str
            The lines of generated code.
        """
        # index and read arrays (index arrays first)
        lines = []
        for varname in itertools.chain(indices, read):
            var = variables[varname]
            index = variable_indices[varname]
            line = f"{varname} = {self.get_array_name(var)}"
            if index not in self.iterate_all:
                line += f"[{index}]"
            elif varname in write:
                # avoid potential issues with aliased variables, see github #259
                line += ".copy()"
            lines.append(line)
        return lines

    def write_arrays(self, statements, read, write, variables, variable_indices):
        """
        Generate the lines that store written variables back into their arrays.

        The write-back of a variable is skipped when its index is in
        ``self.iterate_all``, it is not in ``read``, and every statement
        writing to it is in-place.

        Returns
        -------
        list of str
            The lines of generated code.
        """
        lines = []
        for varname in write:
            var = variables[varname]
            index_var = variable_indices[varname]
            # If all operations were inplace and we're operating on the whole
            # vector, we don't need to write the array back
            all_inplace = (
                index_var in self.iterate_all
                and varname not in read
                and all(stmt.inplace for stmt in statements if stmt.var == varname)
            )
            if all_inplace:
                continue
            array_name = self.get_array_name(var)
            if index_var in self.iterate_all:
                target = f"{array_name}[:]"
            else:
                target = f"{array_name}[{index_var}]"
            lines.append(f"{target} = {varname}")
        return lines

    def conditional_write(
        self, line, stmt, variables, conditional_write_vars, created_vars
    ):
        """
        Restrict a line of code to the elements where writing is allowed.

        If ``stmt.var`` has a conditional-write entry (e.g. ``not_refractory``),
        every non-scalar `ArrayVariable`, every name in ``created_vars`` and
        ``_vectorisation_idx`` occurring in ``line`` is indexed with that
        condition. Otherwise ``line`` is returned unchanged.

        Returns
        -------
        str
            The (possibly) rewritten line.
        """
        if stmt.var in conditional_write_vars:
            subs = {}
            index = conditional_write_vars[stmt.var]
            # we replace all var with var[index], but actually we use this
            # repl_string first because we don't want to end up with lines like
            # x[not_refractory[not_refractory]] when multiple substitution
            # passes are invoked
            repl_string = (  # this string shouldn't occur anywhere I hope! :)
                "#$(@#&$@$*U#@)$@(#"
            )
            for varname, var in list(variables.items()):
                if isinstance(var, ArrayVariable) and not var.scalar:
                    subs[varname] = f"{varname}[{repl_string}]"
            # all newly created vars are arrays and will need indexing
            for varname in created_vars:
                subs[varname] = f"{varname}[{repl_string}]"
            # Also index _vectorisation_idx so that e.g. rand() works correctly
            subs["_vectorisation_idx"] = f"_vectorisation_idx[{repl_string}]"
            line = word_substitute(line, subs)
            line = line.replace(repl_string, index)
        return line

    def translate_one_statement_sequence(self, statements, scalar=False):
        """
        Translate a sequence of statements into lines of numpy code.

        Scalar code, and vector code without repeated indices, is translated
        statement by statement with conditional writes applied. Otherwise the
        translation is delegated to `vectorise_code`.

        Returns
        -------
        list of str
            The lines of generated code.
        """
        variables = self.variables
        variable_indices = self.variable_indices
        read, write, indices, conditional_write_vars = self.arrays_helper(statements)
        lines = []
        all_unique = not self.has_repeated_indices(statements)
        if scalar or all_unique:
            # Simple translation
            lines.extend(
                self.read_arrays(read, write, indices, variables, variable_indices)
            )
            created_vars = {stmt.var for stmt in statements if stmt.op == ":="}
            for stmt in statements:
                line = self.translate_statement(stmt)
                line = self.conditional_write(
                    line, stmt, variables, conditional_write_vars, created_vars
                )
                lines.append(line)
            lines.extend(
                self.write_arrays(statements, read, write, variables, variable_indices)
            )
        else:
            # More complex translation to deal with repeated indices
            lines.extend(self.vectorise_code(statements, variables, variable_indices))
        return lines
################################################################################
# Implement functions
################################################################################
# Functions that exist under the same name in numpy: the numpy attribute of
# that name is registered as the implementation for this code generator.
for func_name in (
    "sin",
    "cos",
    "tan",
    "sinh",
    "cosh",
    "tanh",
    "exp",
    "log",
    "log10",
    "sqrt",
    "arcsin",
    "arccos",
    "arctan",
    "abs",
    "sign",
):
    func = getattr(np, func_name)
    DEFAULT_FUNCTIONS[func_name].implementations.add_implementation(
        NumpyCodeGenerator, code=func
    )
# Functions that are implemented in a somewhat special way
def randn_func(vectorisation_idx):
    """
    Draw standard-normal random numbers, one per element of the index.

    If ``vectorisation_idx`` has no length (e.g. a scalar index), a single
    number is drawn instead of an array.
    """
    try:
        count = len(vectorisation_idx)
    except TypeError:
        # scalar value
        return np.random.randn()
    return np.random.randn(count)
def rand_func(vectorisation_idx):
    """
    Draw uniform random numbers in [0, 1), one per element of the index.

    If ``vectorisation_idx`` has no length (e.g. a scalar index), a single
    number is drawn instead of an array.
    """
    try:
        count = len(vectorisation_idx)
    except TypeError:
        # scalar value
        return np.random.rand()
    return np.random.rand(count)
def poisson_func(lam, vectorisation_idx):
    """
    Draw Poisson-distributed random numbers, one per element of the index.

    Parameters
    ----------
    lam
        The rate parameter, passed unchanged to ``np.random.poisson``.
    vectorisation_idx
        If it has a length, that many values are drawn. Otherwise (e.g. a
        scalar index) a single value is drawn.

    Returns
    -------
    An array of ``len(vectorisation_idx)`` samples, or a single sample.
    """
    # Only probe the length inside ``try``: a TypeError raised by
    # ``np.random.poisson`` itself (e.g. for an invalid ``lam``) must not be
    # mistaken for the scalar-index case.
    try:
        size = len(vectorisation_idx)
    except TypeError:
        # scalar value
        return np.random.poisson(lam)
    return np.random.poisson(lam, size=size)
# Register the random-number helpers above as numpy implementations.
for _name, _implementation in (
    ("randn", randn_func),
    ("rand", rand_func),
    ("poisson", poisson_func),
):
    DEFAULT_FUNCTIONS[_name].implementations.add_implementation(
        NumpyCodeGenerator, code=_implementation
    )
# Do not leave the loop variables behind in the module namespace.
del _name, _implementation
def clip_func(array, a_min, a_max):
    """Limit the values of ``array`` to [``a_min``, ``a_max``] via ``np.clip``."""
    return np.clip(array, a_min, a_max)


DEFAULT_FUNCTIONS["clip"].implementations.add_implementation(
    NumpyCodeGenerator, code=clip_func
)


def int_func(value):
    """Convert ``value`` to ``np.int32``."""
    return np.int32(value)


DEFAULT_FUNCTIONS["int"].implementations.add_implementation(
    NumpyCodeGenerator, code=int_func
)


def ceil_func(value):
    """Round ``value`` up with ``np.ceil`` and convert it to ``np.int32``."""
    return np.int32(np.ceil(value))


DEFAULT_FUNCTIONS["ceil"].implementations.add_implementation(
    NumpyCodeGenerator, code=ceil_func
)


def floor_func(value):
    """Round ``value`` down with ``np.floor`` and convert it to ``np.int32``."""
    return np.int32(np.floor(value))


DEFAULT_FUNCTIONS["floor"].implementations.add_implementation(
    NumpyCodeGenerator, code=floor_func
)
# We need to explicitly add an implementation for the timestep function,
# otherwise Brian would *add* units during simulation, thinking that the
# timestep function would not work correctly otherwise. This would slow the
# function down significantly. Registering `timestep` itself (imported from
# brian2.core.functions) as the numpy implementation avoids this.
DEFAULT_FUNCTIONS["timestep"].implementations.add_implementation(
    NumpyCodeGenerator, code=timestep
)