"""
Module containg the StateUpdateMethod for integration using the ODE solver
provided in the GNU Scientific Library (GSL)
"""
import sys
from brian2.utils.logger import get_logger
from ..core.preferences import prefs
from ..devices.device import RuntimeDevice, all_devices, auto_target
from .base import (
StateUpdateMethod,
UnsupportedEquationsException,
extract_method_options,
)
logger = get_logger(__name__)

# Public API: one ready-made GSL state updater per supported GSL stepper
__all__ = [
    "gsl_rk2",
    "gsl_rk4",
    "gsl_rkf45",
    "gsl_rkck",
    "gsl_rk8pd",
]

# Default ``method_options`` for the GSL state updaters. This dictionary is
# passed to ``extract_method_options`` together with the user-supplied
# options in ``GSLStateUpdater.__call__``.
default_method_options = dict(
    adaptable_timestep=True,
    absolute_error=1e-6,
    absolute_error_per_variable=None,
    max_steps=100,
    use_last_timestep=True,
    save_failed_steps=False,
    save_step_count=False,
)
class GSLContainer:
    """
    Class that contains information (equation- or integrator-related) required
    for later code generation.

    Instances are returned by `GSLStateUpdater` and are themselves callable:
    calling one with a state updater object transfers the stored information
    to that object and returns the abstract code.

    Parameters
    ----------
    method_options : dict
        The method options for the GSL integrator.
    integrator : str
        Name of the GSL integrator (e.g. ``"rk2"`` or ``"rkf45"``).
    abstract_code : str, optional
        The abstract code (translated equations).
    needed_variables : list of str, optional
        Names of additional variables the generated code needs. Defaults to
        an empty list.
    variable_flags : dict, optional
        Mapping from variable names to the flags of their equations. Defaults
        to an empty dictionary.
    """

    def __init__(
        self,
        method_options,
        integrator,
        abstract_code=None,
        needed_variables=None,
        variable_flags=None,
    ):
        self.method_options = method_options
        self.integrator = integrator
        self.abstract_code = abstract_code
        # ``None`` sentinels avoid sharing one mutable default across instances
        self.needed_variables = [] if needed_variables is None else needed_variables
        # The flags are a name -> flags mapping, so the default is a dict
        self.variable_flags = {} if variable_flags is None else variable_flags

    def get_codeobj_class(self):
        """
        Return codeobject class based on target language and device.

        Choose which version of the GSL `CodeObject` to use. If the current
        device is exactly a `CPPStandaloneDevice` (subclasses are deliberately
        not accepted), we want the `GSLCPPStandaloneCodeObject`, and the GSL
        libraries, headers and include directories are registered with the
        device. For a `RuntimeDevice`, the return value is based on
        ``prefs.codegen.target``. Only the ``'cython'`` target is supported.

        Returns
        -------
        code_object : class
            The respective `CodeObject` class (i.e. either
            `GSLCythonCodeObject` or `GSLCPPStandaloneCodeObject`).

        Raises
        ------
        NotImplementedError
            If GSL integration is not available for the current device or
            code generation target.
        """
        # imports in this function to avoid circular imports
        from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
        from brian2.devices.device import get_device

        device = get_device()
        if (
            device.__class__ is CPPStandaloneDevice
        ):  # We do not want to accept subclasses here
            from ..devices.cpp_standalone.GSLcodeobject import (
                GSLCPPStandaloneCodeObject,
            )

            # In runtime mode (i.e. Cython), the compiler settings are
            # added for each `CodeObject` (only the files that use the GSL are
            # linked to the GSL). However, in C++ standalone mode, there are
            # global compiler settings that are used for all files (stored in
            # the `CPPStandaloneDevice`). Furthermore, header file includes are
            # directly inserted into the template instead of added during the
            # compilation phase. Therefore, we have to add the options here
            # instead of in `GSLCPPStandaloneCodeObject`.
            # Add the GSL library if it has not yet been added
            if "gsl" not in device.libraries:
                device.libraries += ["gsl", "gslcblas"]
                device.headers += [
                    "<stdio.h>",
                    "<stdlib.h>",
                    "<gsl/gsl_odeiv2.h>",
                    "<gsl/gsl_errno.h>",
                    "<gsl/gsl_matrix.h>",
                ]
                if sys.platform == "win32":
                    device.define_macros += [("WIN32", "1"), ("GSL_DLL", "1")]
                if prefs.GSL.directory is not None:
                    device.include_dirs += [prefs.GSL.directory]
            return GSLCPPStandaloneCodeObject

        if isinstance(device, RuntimeDevice):
            if prefs.codegen.target == "auto":
                target_name = auto_target().class_name
            else:
                target_name = prefs.codegen.target
            if target_name == "cython":
                # Only import the Cython code object when it is actually used
                from ..codegen.runtime.GSLcython_rt import GSLCythonCodeObject

                return GSLCythonCodeObject
            raise NotImplementedError(
                "GSL integration has not been implemented for "
                f"the '{target_name}' code generation target."
                "\nUse the 'cython' code generation target, "
                "or switch to the 'cpp_standalone' device."
            )

        device_names = [name for name, dev in all_devices.items() if dev is device]
        # Fall back to the class name if the device is not (uniquely)
        # registered, instead of failing with an assert/IndexError
        if len(device_names) == 1:
            device_name = device_names[0]
        else:
            device_name = type(device).__name__
        raise NotImplementedError(
            "GSL integration has not been implemented for "
            f"the '{device_name}' device."
            "\nUse either the 'cpp_standalone' device, "
            "or the runtime device with target language "
            "'cython'."
        )

    def __call__(self, obj):
        """
        Transfer the code object class saved in self to the object sent as an
        argument.

        This method is what calling a `GSLStateUpdater` ultimately provides.
        `StateUpdateMethod` subclasses normally only return abstract code.
        GSL integration needs more than just the abstract code, because the
        state updater requires its own `CodeObject`, different from the one
        used by the other `NeuronGroup` objects. This method therefore adds
        this `CodeObject` class to the `StateUpdater` object. It also stores
        the integrator, the method options, the variable flags, and the
        variables ``'t'``, ``'dt'`` plus any other variables that are needed
        in the `GSLCodeGenerator`.

        Parameters
        ----------
        obj : `GSLStateUpdater`
            The object that the codeobj_class and other variables need to be
            transferred to.

        Returns
        -------
        abstract_code : str
            The abstract code (translated equations) that is returned
            conventionally by Brian and used for later code generation in the
            `CodeGenerator.translate` method.
        """
        obj.codeobj_class = self.get_codeobj_class()
        obj._gsl_variable_flags = self.variable_flags
        obj.method_options = self.method_options
        obj.integrator = self.integrator
        # Build a new list so that self.needed_variables is not modified
        obj.needed_variables = ["t", "dt"] + self.needed_variables
        return self.abstract_code
class GSLStateUpdater(StateUpdateMethod):
    """
    A state updater that rewrites the differential equations so that the GSL
    generator knows how to write the code in the target language.

    .. versionadded:: 2.1

    Parameters
    ----------
    integrator : str
        Name of the GSL integrator (e.g. ``"rk2"`` or ``"rkf45"``). It is
        passed on to the returned `GSLContainer`.
    """

    def __init__(self, integrator):
        self.integrator = integrator

    def __call__(self, equations, variables=None, method_options=None):
        """
        Translate equations to abstract_code.

        Parameters
        ----------
        equations : `Equations`
            Object containing the equations that describe the ODE system.
        variables : dict, optional
            Dictionary containing str, `Variable` pairs.
        method_options : dict, optional
            Options for the GSL integrator. They are processed with
            `extract_method_options` together with ``default_method_options``.

        Returns
        -------
        method : `GSLContainer`
            Callable that needs to be called with the `StateUpdater` to add the
            `CodeObject` class and some other variables, so that these can be
            sent to the `CodeGenerator`.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic.
        """
        logger.warn(
            "Integrating equations with GSL is still considered experimental", once=True
        )
        method_options = extract_method_options(method_options, default_method_options)
        if equations.is_stochastic:
            raise UnsupportedEquationsException(
                "Cannot solve stochastic equations with the GSL state updater."
            )

        # The approach is to 'tag' the differential equation variables (with a
        # ``_gsl_`` prefix and an ``_f<index>`` suffix) so that they can be
        # translated to GSL code.
        diff_eqs = equations.get_substituted_expressions(variables)
        code = []
        # If a variable does not occur on the right-hand side of its own
        # equation, Brian does not know to add it to the namespace, so every
        # differential-equation variable is added to needed_variables.
        diff_vars = []
        for index, (diff_name, expr) in enumerate(diff_eqs):
            diff_vars.append(diff_name)
            code.append(f"_gsl_{diff_name}_f{index} = {expr}")

        # Pass on the flags of the equations, some of them are needed in the
        # GSL generator.
        flags = {
            eq_name: eq_obj.flags
            for eq_name, eq_obj in equations._equations.items()
            if len(eq_obj.flags) > 0
        }

        return GSLContainer(
            method_options=method_options,
            integrator=self.integrator,
            abstract_code="\n".join(code),
            needed_variables=diff_vars,
            variable_flags=flags,
        )
# Ready-to-use GSL state updaters, one per supported GSL integrator
# (these are the names exported via __all__).
gsl_rk2 = GSLStateUpdater(integrator="rk2")
gsl_rk4 = GSLStateUpdater(integrator="rk4")
gsl_rkf45 = GSLStateUpdater(integrator="rkf45")
gsl_rkck = GSLStateUpdater(integrator="rkck")
gsl_rk8pd = GSLStateUpdater(integrator="rk8pd")