"""
Module providing the base `CodeObject` and related functions.
"""
import collections
import copy
import platform
import weakref
from brian2.core.functions import Function
from brian2.core.names import Nameable
from brian2.equations.unitcheck import check_units_statements
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import code_representation, indent
from .translation import analyse_identifiers
__all__ = ["CodeObject", "constant_or_scalar"]
logger = get_logger(__name__)
#: Dictionary with basic information about the current system (OS, etc.)
sys_info = {
"system": platform.system(),
"architecture": platform.architecture(),
"machine": platform.machine(),
}
[docs]def constant_or_scalar(varname, variable):
"""
Convenience function to generate code to access the value of a variable.
Will return ``'varname'`` if the ``variable`` is a constant, and
``array_name[0]`` if it is a scalar array.
"""
from brian2.devices.device import get_device # avoid circular import
if variable.array:
return f"{get_device().get_array_name(variable)}[0]"
else:
return f"{varname}"
[docs]class CodeObject(Nameable):
"""
Executable code object.
The ``code`` can either be a string or a
`brian2.codegen.templates.MultiTemplate`.
After initialisation, the code is compiled with the given namespace
using ``code.compile(namespace)``.
Calling ``code(key1=val1, key2=val2)`` executes the code with the given
variables inserted into the namespace.
"""
#: The `CodeGenerator` class used by this `CodeObject`
generator_class = None
#: A short name for this type of `CodeObject`
class_name = None
def __init__(
self,
owner,
code,
variables,
variable_indices,
template_name,
template_source,
compiler_kwds,
name="codeobject*",
):
Nameable.__init__(self, name=name)
try:
owner = weakref.proxy(owner)
except TypeError:
pass # if owner was already a weakproxy then this will be the error raised
self.owner = owner
self.code = code
self.compiled_code = {}
self.variables = variables
self.variable_indices = variable_indices
self.template_name = template_name
self.template_source = template_source
self.compiler_kwds = compiler_kwds
[docs] @classmethod
def is_available(cls):
"""
Whether this target for code generation is available. Should use a
minimal example to check whether code generation works in general.
"""
raise NotImplementedError(
f"CodeObject class {cls.__name__} is missing an 'is_available' method."
)
[docs] def update_namespace(self):
"""
Update the namespace for this timestep. Should only deal with variables
where *the reference* changes every timestep, i.e. where the current
reference in `namespace` is not correct.
"""
pass
[docs] def compile_block(self, block):
raise NotImplementedError("Implement compile_block method")
[docs] def compile(self):
for block in ["before_run", "run", "after_run"]:
self.compiled_code[block] = self.compile_block(block)
[docs] def __call__(self, **kwds):
self.update_namespace()
self.namespace.update(**kwds)
return self.run()
[docs] def run_block(self, block):
raise NotImplementedError("Implement run_block method")
[docs] def before_run(self):
"""
Runs the preparation code in the namespace. This code will only be
executed once per run.
Returns
-------
return_value : dict
A dictionary with the keys corresponding to the `output_variables`
defined during the call of `CodeGenerator.code_object`.
"""
return self.run_block("before_run")
[docs] def run(self):
"""
Runs the main code in the namespace.
Returns
-------
return_value : dict
A dictionary with the keys corresponding to the `output_variables`
defined during the call of `CodeGenerator.code_object`.
"""
return self.run_block("run")
[docs] def after_run(self):
"""
Runs the finalizing code in the namespace. This code will only be
executed once per run.
Returns
-------
return_value : dict
A dictionary with the keys corresponding to the `output_variables`
defined during the call of `CodeGenerator.code_object`.
"""
return self.run_block("after_run")
def _error_msg(code, name):
"""
Little helper function for error messages.
"""
error_msg = f"Error generating code for code object '{name}' "
code_lines = [line for line in code.split("\n") if len(line.strip())]
# If the abstract code is only one line, display it in full
if len(code_lines) <= 1:
error_msg += f"from this abstract code: '{code_lines[0]}'\n"
else:
error_msg += (
f"from {len(code_lines)} lines of abstract code, first line is: "
"'code_lines[0]'\n"
)
return error_msg
[docs]def check_compiler_kwds(compiler_kwds, accepted_kwds, target):
"""
Internal function to check the provided compiler keywords against the list
of understood keywords.
Parameters
----------
compiler_kwds : dict
Dictionary of compiler keywords and respective list of values.
accepted_kwds : list of str
The compiler keywords understood by the code generation target
target : str
The name of the code generation target (used for the error message).
Raises
------
ValueError
If a compiler keyword is not understood
"""
for key in compiler_kwds:
if key not in accepted_kwds:
formatted_kwds = ", ".join(f"'{kw}'" for kw in accepted_kwds)
raise ValueError(
f"The keyword argument '{key}' is not understood by "
f"the code generation target '{target}'. The valid "
f"arguments are: {formatted_kwds}."
)
def _merge_compiler_kwds(list_of_kwds):
"""
Merges a list of keyword dictionaries. Values in these dictionaries are
lists of values, the merged dictionaries will contain the concatenations
of lists specified for the same key.
Parameters
----------
list_of_kwds : list of dict
A list of compiler keyword dictionaries that should be merged.
Returns
-------
merged_kwds : dict
The merged dictionary
"""
merged_kwds = collections.defaultdict(list)
for kwds in list_of_kwds:
for key, values in kwds.items():
if not isinstance(values, list):
raise TypeError(
f"Compiler keyword argument '{key}' requires a list of values."
)
merged_kwds[key].extend(values)
return merged_kwds
def _gather_compiler_kwds(function, codeobj_class):
"""
Gather all the compiler keywords for a function and its dependencies.
Parameters
----------
function : `Function`
The function for which the compiler keywords should be gathered
codeobj_class : type
The class of `CodeObject` to use
Returns
-------
kwds : dict
A dictionary with the compiler arguments, a list of values for each
key.
"""
implementation = function.implementations[codeobj_class]
all_kwds = [implementation.compiler_kwds]
if implementation.dependencies is not None:
for dependency in implementation.dependencies.values():
all_kwds.append(_gather_compiler_kwds(dependency, codeobj_class))
return _merge_compiler_kwds(all_kwds)
[docs]def create_runner_codeobj(
group,
code,
template_name,
run_namespace,
user_code=None,
variable_indices=None,
name=None,
check_units=True,
needed_variables=None,
additional_variables=None,
template_kwds=None,
override_conditional_write=None,
codeobj_class=None,
):
"""Create a `CodeObject` for the execution of code in the context of a
`Group`.
Parameters
----------
group : `Group`
The group where the code is to be run
code : str or dict of str
The code to be executed.
template_name : str
The name of the template to use for the code.
run_namespace : dict-like
An additional namespace that is used for variable lookup (either
an explicitly defined namespace or one taken from the local
context).
user_code : str, optional
The code that had been specified by the user before other code was
added automatically. If not specified, will be assumed to be identical
to ``code``.
variable_indices : dict-like, optional
A mapping from `Variable` objects to index names (strings). If none is
given, uses the corresponding attribute of `group`.
name : str, optional
A name for this code object, will use ``group + '_codeobject*'`` if
none is given.
check_units : bool, optional
Whether to check units in the statement. Defaults to ``True``.
needed_variables: list of str, optional
A list of variables that are neither present in the abstract code, nor
in the ``USES_VARIABLES`` statement in the template. This is only
rarely necessary, an example being a `StateMonitor` where the
names of the variables are neither known to the template nor included
in the abstract code statements.
additional_variables : dict-like, optional
A mapping of names to `Variable` objects, used in addition to the
variables saved in `group`.
template_kwds : dict, optional
A dictionary of additional information that is passed to the template.
override_conditional_write: list of str, optional
A list of variable names which are used as conditions (e.g. for
refractoriness) which should be ignored.
codeobj_class : class, optional
The `CodeObject` class to run code with. If not specified, defaults to
the `group`'s ``codeobj_class`` attribute.
"""
if name is None:
if group is not None:
name = f"{group.name}_{template_name}_codeobject*"
else:
name = f"{template_name}_codeobject*"
if user_code is None:
user_code = code
if isinstance(code, str):
code = {None: code}
user_code = {None: user_code}
msg = (
f"Creating code object (group={group.name}, template name={template_name}) for"
" abstract code:\n"
)
msg += indent(code_representation(code))
logger.diagnostic(msg)
from brian2.devices import get_device
device = get_device()
if override_conditional_write is None:
override_conditional_write = set()
else:
override_conditional_write = set(override_conditional_write)
if codeobj_class is None:
codeobj_class = device.code_object_class(group.codeobj_class)
else:
codeobj_class = device.code_object_class(codeobj_class)
template = getattr(codeobj_class.templater, template_name)
template_variables = getattr(template, "variables", None)
all_variables = dict(group.variables)
if additional_variables is not None:
all_variables.update(additional_variables)
# Determine the identifiers that were used
identifiers = set()
user_identifiers = set()
for v, u_v in zip(code.values(), user_code.values()):
_, uk, u = analyse_identifiers(v, all_variables, recursive=True)
identifiers |= uk | u
_, uk, u = analyse_identifiers(u_v, all_variables, recursive=True)
user_identifiers |= uk | u
# Add variables that are not in the abstract code, nor specified in the
# template but nevertheless necessary
if needed_variables is None:
needed_variables = []
# Resolve all variables (variables used in the code and variables needed by
# the template)
variables = group.resolve_all(
sorted(identifiers | set(needed_variables) | set(template_variables)),
# template variables are not known to the user:
user_identifiers=user_identifiers,
additional_variables=additional_variables,
run_namespace=run_namespace,
)
# We raise this error only now, because there is some non-obvious code path
# where Jinja tries to get a Synapse's "name" attribute via syn['name'],
# which then triggers the use of the `group_get_indices` template which does
# not exist for standalone. Putting the check for template == None here
# means we will first raise an error about the unknown identifier which will
# then make Jinja try syn.name
if template is None:
codeobj_class_name = codeobj_class.class_name or codeobj_class.__name__
raise AttributeError(
f"'{codeobj_class_name}' does not provide a code "
f"generation template '{template_name}'"
)
conditional_write_variables = {}
# Add all the "conditional write" variables
for var in variables.values():
cond_write_var = getattr(var, "conditional_write", None)
if cond_write_var in override_conditional_write:
continue
if cond_write_var is not None:
if (
cond_write_var.name in variables
and not variables[cond_write_var.name] is cond_write_var
):
logger.diagnostic(
f"Variable '{cond_write_var.name}' is needed for the "
"conditional write mechanism of variable "
f"'{var.name}'. Its name is already used for "
f"{variables[cond_write_var.name]!r}."
)
else:
conditional_write_variables[cond_write_var.name] = cond_write_var
variables.update(conditional_write_variables)
if check_units:
for c in code.values():
# This is the first time that the code is parsed, catch errors
try:
check_units_statements(c, variables)
except (SyntaxError, ValueError) as ex:
error_msg = _error_msg(c, name)
raise ValueError(error_msg) from ex
all_variable_indices = copy.copy(group.variables.indices)
if additional_variables is not None:
all_variable_indices.update(additional_variables.indices)
if variable_indices is not None:
all_variable_indices.update(variable_indices)
# Make "conditional write" variables use the same index as the variable
# that depends on them
for varname, var in variables.items():
cond_write_var = getattr(var, "conditional_write", None)
if cond_write_var is not None:
all_variable_indices[cond_write_var.name] = all_variable_indices[varname]
# Check that all functions are available
for varname, value in variables.items():
if isinstance(value, Function):
try:
value.implementations[codeobj_class]
except KeyError as ex:
# if we are dealing with numpy, add the default implementation
from brian2.codegen.runtime.numpy_rt import NumpyCodeObject
if codeobj_class is NumpyCodeObject:
value.implementations.add_numpy_implementation(value.pyfunc)
else:
raise NotImplementedError(
f"Cannot use function '{varname}': {ex}"
) from ex
# Gather the additional compiler arguments declared by function
# implementations
all_keywords = [
_gather_compiler_kwds(var, codeobj_class)
for var in variables.values()
if isinstance(var, Function)
]
compiler_kwds = _merge_compiler_kwds(all_keywords)
# Add the indices needed by the variables
for varname in list(variables):
var_index = all_variable_indices[varname]
if var_index not in ("_idx", "0"):
variables[var_index] = all_variables[var_index]
return device.code_object(
owner=group,
name=name,
abstract_code=code,
variables=variables,
template_name=template_name,
variable_indices=all_variable_indices,
template_kwds=template_kwds,
codeobj_class=codeobj_class,
override_conditional_write=override_conditional_write,
compiler_kwds=compiler_kwds,
)