Source code for brian2.stateupdaters.GSL

"""
Module containing the StateUpdateMethod for integration using the ODE solver
provided in the GNU Scientific Library (GSL)
"""

import sys

from brian2.utils.logger import get_logger

from ..core.preferences import prefs
from ..devices.device import RuntimeDevice, all_devices, auto_target
from .base import (
    StateUpdateMethod,
    UnsupportedEquationsException,
    extract_method_options,
)

logger = get_logger(__name__)

# Public API of this module: the pre-built state updater instances (one per
# integrator name) that are created at module level below.
__all__ = ["gsl_rk2", "gsl_rk4", "gsl_rkf45", "gsl_rkck", "gsl_rk8pd"]

# Default values for the options of the GSL state updaters.
# `GSLStateUpdater.__call__` hands this dictionary, together with the
# user-supplied ``method_options``, to `extract_method_options`; the resulting
# options are stored on the `GSLContainer` it returns.
# NOTE(review): the meaning of the individual options is not visible in this
# module -- presumably they are interpreted by the GSL code generator; confirm
# there before changing a default.
default_method_options = {
    "adaptable_timestep": True,
    "absolute_error": 1e-6,
    "absolute_error_per_variable": None,
    "max_steps": 100,
    "use_last_timestep": True,
    "save_failed_steps": False,
    "save_step_count": False,
}


class GSLContainer:
    """
    Class that contains information (equation- or integrator-related) required
    for later code generation.

    Parameters
    ----------
    method_options : dict
        The options of the integration method; transferred unchanged to the
        target object in `__call__`.
    integrator : str
        Name of the GSL integrator (e.g. ``"rk4"``).
    abstract_code : str, optional
        The abstract code (translated equations) returned by `__call__`.
    needed_variables : list, optional
        Names of additional variables needed during code generation
        (``'t'`` and ``'dt'`` are always added in `__call__`). Defaults to an
        empty list.
    variable_flags : optional
        Flags of the equations, transferred unchanged to the target object in
        `__call__`. Defaults to an empty list.
    """

    def __init__(
        self,
        method_options,
        integrator,
        abstract_code=None,
        needed_variables=None,
        variable_flags=None,
    ):
        # Use `None` sentinels so that instances never share a mutable default
        if needed_variables is None:
            needed_variables = []
        if variable_flags is None:
            variable_flags = []
        self.method_options = method_options
        self.integrator = integrator
        self.abstract_code = abstract_code
        self.needed_variables = needed_variables
        self.variable_flags = variable_flags

    def get_codeobj_class(self):
        """
        Return codeobject class based on target language and device.

        Choose which version of the GSL `CodeObject` to use. If the current
        device is exactly a `CPPStandaloneDevice` (subclasses are deliberately
        not accepted), then we want the `GSLCPPStandaloneCodeObject`.
        Otherwise the return value is based on prefs.codegen.target.

        Returns
        -------
        code_object : class
            The respective `CodeObject` class (i.e. either
            `GSLCythonCodeObject` or `GSLCPPStandaloneCodeObject`).

        Raises
        ------
        NotImplementedError
            If the current device is a runtime device with a code generation
            target other than ``'cython'``, or if it is neither the C++
            standalone device nor a runtime device.
        """
        # imports in this function to avoid circular imports
        from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
        from brian2.devices.device import get_device

        from ..codegen.runtime.GSLcython_rt import GSLCythonCodeObject

        device = get_device()
        # We do not want to accept subclasses here
        if device.__class__ is CPPStandaloneDevice:
            from ..devices.cpp_standalone.GSLcodeobject import (
                GSLCPPStandaloneCodeObject,
            )

            # In runtime mode (i.e. Cython), the compiler settings are
            # added for each `CodeObject` (only the files that use the GSL are
            # linked to the GSL). However, in C++ standalone mode, there are
            # global compiler settings that are used for all files (stored in
            # the `CPPStandaloneDevice`). Furthermore, header file includes are
            # directly inserted into the template instead of added during the
            # compilation phase. Therefore, we have to add the options here
            # instead of in `GSLCPPStandaloneCodeObject`

            # Add the GSL library if it has not yet been added
            if "gsl" not in device.libraries:
                device.libraries += ["gsl", "gslcblas"]
                device.headers += [
                    "<stdio.h>",
                    "<stdlib.h>",
                    "<gsl/gsl_odeiv2.h>",
                    "<gsl/gsl_errno.h>",
                    "<gsl/gsl_matrix.h>",
                ]
                if sys.platform == "win32":
                    device.define_macros += [("WIN32", "1"), ("GSL_DLL", "1")]
                if prefs.GSL.directory is not None:
                    device.include_dirs += [prefs.GSL.directory]
            return GSLCPPStandaloneCodeObject

        if isinstance(device, RuntimeDevice):
            if prefs.codegen.target == "auto":
                target_name = auto_target().class_name
            else:
                target_name = prefs.codegen.target
            if target_name == "cython":
                return GSLCythonCodeObject
            raise NotImplementedError(
                "GSL integration has not been implemented for "
                f"the '{target_name}' code generation target."
                "\nUse the 'cython' code generation target, "
                "or switch to the 'cpp_standalone' device."
            )

        # Look up the name under which the device has been registered; fall
        # back to its class name so that we always raise the informative
        # `NotImplementedError` (an `assert` would vanish under ``-O`` and
        # raise an unrelated `AssertionError` otherwise).
        registered_names = [
            name for name, dev in all_devices.items() if dev is device
        ]
        device_name = (
            registered_names[0] if registered_names else type(device).__name__
        )
        raise NotImplementedError(
            "GSL integration has not been implemented for "
            f"the '{device_name}' device."
            "\nUse either the 'cpp_standalone' device, "
            "or the runtime device with target language "
            "'cython'."
        )

    def __call__(self, obj):
        """
        Transfer the code object class saved in self to the object sent as an
        argument.

        This method is returned when calling `GSLStateUpdater`. This class
        inherits from `StateUpdateMethod` which originally only returns
        abstract code. However, with GSL this returns a method because more is
        needed than just the abstract code: the state updater requires its own
        CodeObject that is different from the other `NeuronGroup` objects.
        This method adds this `CodeObject` to the `StateUpdater` object (and
        also adds the variables 't', 'dt', and other variables that are needed
        in the `GSLCodeGenerator`).

        Parameters
        ----------
        obj : `GSLStateUpdater`
            the object that the codeobj_class and other variables need to be
            transferred to

        Returns
        -------
        abstract_code : str
            The abstract code (translated equations), that is returned
            conventionally by brian and used for later code generation in the
            `CodeGenerator.translate` method.
        """
        obj.codeobj_class = self.get_codeobj_class()
        obj._gsl_variable_flags = self.variable_flags
        obj.method_options = self.method_options
        obj.integrator = self.integrator
        # 't' and 'dt' are always needed, in addition to the stored variables
        obj.needed_variables = ["t", "dt"] + self.needed_variables
        return self.abstract_code
class GSLStateUpdater(StateUpdateMethod):
    """
    A state updater that rewrites the differential equations so that the GSL
    generator knows how to write the code in the target language.

    .. versionadded:: 2.1

    Parameters
    ----------
    integrator : str
        Name of the GSL integrator (e.g. ``"rk4"``); passed on to the
        `GSLContainer` created in `__call__`.
    """

    def __init__(self, integrator):
        self.integrator = integrator

    def __call__(self, equations, variables=None, method_options=None):
        """
        Translate equations to abstract_code.

        Parameters
        ----------
        equations : `Equations`
            object containing the equations that describe the ODE system
        variables : dict
            dictionary containing str, `Variable` pairs
        method_options : dict, optional
            Options for the integration method. They are combined with
            ``default_method_options`` by `extract_method_options`.

        Returns
        -------
        method : callable
            Method that needs to be called with `StateUpdater` to add
            CodeObject class and some other variables so these can be sent to
            the `CodeGenerator`

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic.
        """
        logger.warn(
            "Integrating equations with GSL is still considered experimental", once=True
        )
        method_options = extract_method_options(method_options, default_method_options)

        if equations.is_stochastic:
            raise UnsupportedEquationsException(
                "Cannot solve stochastic equations with the GSL state updater."
            )

        # the approach is to 'tag' the differential equation variables so they
        # can be translated to GSL code
        diff_eqs = equations.get_substituted_expressions(variables)

        code = []
        diff_vars = []
        for index, (diff_name, expr) in enumerate(diff_eqs):
            # if diff_name does not occur in the right hand side of the
            # equation, Brian does not know to add the variable to the
            # namespace, so we add it to needed_variables
            diff_vars.append(diff_name)
            # the tag carries the position of the variable in the ODE system
            code.append(f"_gsl_{diff_name}_f{index} = {expr}")

        # add flags to variables objects because some of them we need in the
        # GSL generator
        flags = {
            eq_name: eq_obj.flags
            for eq_name, eq_obj in equations._equations.items()
            if len(eq_obj.flags) > 0
        }

        return GSLContainer(
            method_options=method_options,
            integrator=self.integrator,
            abstract_code="\n".join(code),
            needed_variables=diff_vars,
            variable_flags=flags,
        )
# Ready-to-use state updater instances, one per integrator name; these are the
# names exported via ``__all__``.
gsl_rk2 = GSLStateUpdater("rk2")
gsl_rk4 = GSLStateUpdater("rk4")
gsl_rkf45 = GSLStateUpdater("rkf45")
gsl_rkck = GSLStateUpdater("rkck")
gsl_rk8pd = GSLStateUpdater("rk8pd")