"""
Numerical integration functions.
"""
import operator
import string
from functools import reduce
import sympy
from pyparsing import (
Group,
Literal,
ParseException,
Suppress,
Word,
ZeroOrMore,
restOfLine,
)
from sympy.core.sympify import SympifyError
from brian2.equations.codestrings import is_constant_over_dt
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from .base import (
StateUpdateMethod,
UnsupportedEquationsException,
extract_method_options,
)
__all__ = ["milstein", "heun", "euler", "rk2", "rk4", "ExplicitStateUpdater"]
# ===============================================================================
# Class for simple definition of explicit state updaters
# ===============================================================================
def _symbol(name, positive=None):
    """
    Create a real-valued sympy symbol.

    Shorthand for ``sympy.Symbol(name, real=True, positive=positive)``.

    Parameters
    ----------
    name : str
        The name of the symbol.
    positive : bool, optional
        Value forwarded as sympy's ``positive`` assumption. The default
        ``None`` is forwarded unchanged.

    Returns
    -------
    symbol : `sympy.Symbol`
        The symbol carrying the requested assumptions.
    """
    assumptions = {"real": True, "positive": positive}
    return sympy.Symbol(name, **assumptions)
#: Reserved standard symbols, keyed by name. ``__f`` and ``__g`` map to
#: undefined sympy functions; every other entry is a real-valued symbol
#: created with `_symbol`, where ``__t``, ``dt`` and ``t`` are additionally
#: declared positive.
SYMBOLS = {
    "__x": _symbol("__x"),
    "__t": _symbol("__t", positive=True),
    "dt": _symbol("dt", positive=True),
    "t": _symbol("t", positive=True),
    "__f": sympy.Function("__f"),
    "__g": sympy.Function("__g"),
    "__dW": _symbol("__dW"),
}
def split_expression(expr):
    """
    Split an expression into a part containing the function ``f`` and another
    one containing the function ``g``. Returns a tuple of the two expressions
    (as sympy expressions).

    Parameters
    ----------
    expr : str
        An expression containing references to functions ``f`` and ``g``.

    Returns
    -------
    (non_stochastic, stochastic) : tuple of sympy expressions
        A pair of expressions representing the non-stochastic (containing
        function-independent terms and terms involving ``f``) and the
        stochastic part of the expression (terms involving ``g`` and/or ``dW``).
        The stochastic part is ``None`` if the expression contains neither
        ``g`` nor a random variable term with a non-zero factor.

    Raises
    ------
    ValueError
        If the expression does not have the supported form (a sum of a
        constant part, a term linear in ``f``, a power of ``dW`` with a
        factor, and a term linear in ``g``).

    Examples
    --------
    >>> split_expression('dt * __f(__x, __t)')
    (dt*__f(__x, __t), None)
    >>> split_expression('dt * __f(__x, __t) + __dW * __g(__x, __t)')
    (dt*__f(__x, __t), __dW*__g(__x, __t))
    >>> split_expression('1/(2*sqrt(dt))*(__g_support - __g(__x, __t))*(sqrt(__dW))')
    (0, sqrt(__dW)*__g_support/(2*sqrt(dt)) - sqrt(__dW)*__g(__x, __t)/(2*sqrt(dt)))
    """
    f = SYMBOLS["__f"]
    g = SYMBOLS["__g"]
    dW = SYMBOLS["__dW"]

    def wild(name, *also_excluded):
        # Every wildcard is real-valued and must never swallow f or g
        return sympy.Wild(name, exclude=[f, g, *also_excluded], real=True)

    # Arguments of the f and g functions
    x_f = wild("x_f")
    t_f = wild("t_f")
    x_g = wild("x_g")
    t_g = wild("t_g")

    # Reorder the expression so that f(x,t) and g(x,t) are factored out
    sympy_expr = sympy.sympify(expr, locals=SYMBOLS).expand()
    sympy_expr = sympy.collect(sympy_expr, f(x_f, t_f))
    sympy_expr = sympy.collect(sympy_expr, g(x_g, t_g))

    # Constant part, contains neither f, g nor dW
    independent = wild("independent", dW)
    # The exponent of the random number (must not be zero)
    dW_exponent = wild("dW_exponent", dW, 0)
    # The factor for the random number, not containing the g function
    independent_dW = wild("independent_dW", dW)
    # The factor for the f function
    f_factor = wild("f_factor")
    # The factor for the g function
    g_factor = wild("g_factor")

    match_expr = (
        independent
        + f_factor * f(x_f, t_f)
        + independent_dW * dW**dW_exponent
        + g_factor * g(x_g, t_g)
    )
    matches = sympy_expr.match(match_expr)

    if matches is None:
        raise ValueError(
            f'Expression "{sympy_expr}" in the state updater description could not be'
            " parsed."
        )

    # Non-stochastic part: the constant term, plus the f term if there is one
    non_stochastic = matches[independent]
    if x_f in matches:
        non_stochastic = non_stochastic + (
            matches[f_factor] * f(matches[x_f], matches[t_f])
        )

    # Stochastic part
    if independent_dW in matches and matches[independent_dW] != 0:
        # includes a random variable term with a non-zero factor
        stochastic = (
            matches[g_factor] * g(matches[x_g], matches[t_g])
            + matches[independent_dW] * dW ** matches[dW_exponent]
        )
    elif x_g in matches:
        # Does not include a random variable but the g function
        stochastic = matches[g_factor] * g(matches[x_g], matches[t_g])
    else:
        # Contains neither random variable nor g function --> empty
        stochastic = None

    return (non_stochastic, stochastic)
class ExplicitStateUpdater(StateUpdateMethod):
    """
    An object that can be used for defining state updaters via a simple
    description (see below). Resulting instances can be passed to the
    ``method`` argument of the `NeuronGroup` constructor. As other state
    updater functions the `ExplicitStateUpdater` objects are callable,
    returning abstract code when called with an `Equations` object.

    A description of an explicit state updater consists of a (multi-line)
    string, containing assignments to variables and a final "x_new = ...",
    stating the integration result for a single timestep. The assignments
    can be used to define an arbitrary number of intermediate results and
    can refer to ``f(x, t)`` (the function being integrated, as a function of
    ``x``, the previous value of the state variable and ``t``, the time) and
    ``dt``, the size of the timestep.

    For example, to define a Runge-Kutta 4 integrator (already provided as
    `rk4`), use::

            k1 = dt*f(x,t)
            k2 = dt*f(x+k1/2,t+dt/2)
            k3 = dt*f(x+k2/2,t+dt/2)
            k4 = dt*f(x+k3,t+dt)
            x_new = x+(k1+2*k2+2*k3+k4)/6

    Note that for stochastic equations, the function `f` only corresponds to
    the non-stochastic part of the equation. The additional function `g`
    corresponds to the stochastic part that has to be multiplied with the
    stochastic variable xi (a standard normal random variable -- if the
    algorithm needs a random variable with a different variance/mean you have
    to multiply/add it accordingly). Equations with more than one
    stochastic variable do not have to be treated differently, the part
    referring to ``g`` is repeated for all stochastic variables automatically.

    Stochastic integrators can also make reference to ``dW`` (a normal
    distributed random number with variance ``dt``) and ``g(x, t)``, the
    stochastic part of an equation. A stochastic state updater could therefore
    use a description like::

        x_new = x + dt*f(x,t) + g(x, t) * dW

    For simplicity, the same syntax is used for state updaters that only support
    additive noise, even though ``g(x, t)`` does not depend on ``x`` or ``t``
    in that case.

    There are some restrictions on the complexity of the expressions (but most
    can be worked around by using intermediate results as in the above Runge-
    Kutta example): Every statement can only contain the functions ``f`` and
    ``g`` once; The expressions have to be linear in the functions, e.g. you
    can use ``dt*f(x, t)`` but not ``f(x, t)**2``.

    Parameters
    ----------
    description : str
        A state updater description (see above).
    stochastic : {None, 'additive', 'multiplicative'}
        What kind of stochastic equations this state updater supports: ``None``
        means no support of stochastic equations, ``'additive'`` means only
        equations with additive noise and ``'multiplicative'`` means
        supporting arbitrary stochastic equations.
    custom_check : callable, optional
        If given, called as ``custom_check(eqs, variables)`` every time the
        state updater is applied to equations, before any code is generated.

    Raises
    ------
    SyntaxError
        If the parsing of the description failed.

    Notes
    -----
    Since clocks are updated *after* the state update, the time ``t`` used
    in the state update step is still at its previous value. Enumerating the
    states and discrete times, ``x_new = x + dt*f(x, t)`` is therefore
    understood as :math:`x_{i+1} = x_i + dt f(x_i, t_i)`, yielding the correct
    forward Euler integration. If the integrator has to refer to the time at
    the end of the timestep, simply use ``t + dt`` instead of ``t``.

    See also
    --------
    euler, rk2, rk4, milstein
    """

    # ===========================================================================
    # Parsing definitions
    # ===========================================================================
    #: Legal names for temporary variables
    TEMP_VAR = ~Literal("x_new") + Word(
        f"{string.ascii_letters}_", f"{string.ascii_letters + string.digits}_"
    ).setResultsName("identifier")

    #: A single expression
    EXPRESSION = restOfLine.setResultsName("expression")

    #: An assignment statement
    STATEMENT = Group(TEMP_VAR + Suppress("=") + EXPRESSION).setResultsName("statement")

    #: The last line of a state updater description
    OUTPUT = Group(
        Suppress(Literal("x_new")) + Suppress("=") + EXPRESSION
    ).setResultsName("output")

    #: A complete state updater description
    DESCRIPTION = ZeroOrMore(STATEMENT) + OUTPUT

    def __init__(self, description, stochastic=None, custom_check=None):
        self._description = description
        self.stochastic = stochastic
        self.custom_check = custom_check

        try:
            parsed = ExplicitStateUpdater.DESCRIPTION.parseString(
                description, parseAll=True
            )
        except ParseException as p_exc:
            ex = SyntaxError(f"Parsing failed: {str(p_exc.msg)}")
            ex.text = str(p_exc.line)
            ex.offset = p_exc.column
            ex.lineno = p_exc.lineno
            # Keep the parser error attached as the explicit cause
            raise ex from p_exc

        # List of (prefixed temporary variable name, sympy expression) pairs
        self.statements = []
        self.symbols = SYMBOLS.copy()
        for element in parsed:
            expression = str_to_sympy(element.expression)
            # Replace all symbols used in state updater expressions by unique
            # names that cannot clash with user-defined variables or functions
            expression = expression.subs(sympy.Function("f"), self.symbols["__f"])
            expression = expression.subs(sympy.Function("g"), self.symbols["__g"])
            symbols = list(expression.atoms(sympy.Symbol))
            # "dt" keeps its name, everything else gets a "__" prefix
            unique_symbols = [
                symbol if symbol.name == "dt" else _symbol(f"__{symbol.name}")
                for symbol in symbols
            ]
            for symbol, unique_symbol in zip(symbols, unique_symbols):
                expression = expression.subs(symbol, unique_symbol)

            self.symbols.update({symbol.name: symbol for symbol in unique_symbols})
            element_name = element.getName()
            if element_name == "statement":
                self.statements.append((f"__{element.identifier}", expression))
            elif element_name == "output":
                self.output = expression
            else:
                raise AssertionError(f"Unknown element name: {element.getName()}")

    def __repr__(self):
        # recreate a description string: one line per intermediate statement,
        # followed by the output line
        lines = [f"{var} = {expr}" for var, expr in self.statements]
        lines.append(f"x_new = {str(self.output)}")
        description = "\n".join(lines)
        classname = self.__class__.__name__
        return f"{classname}('''{description}''', stochastic={self.stochastic!r})"

    def __str__(self):
        s = f"{self.__class__.__name__}\n"

        if self.statements:
            s += "Intermediate statements:\n"
            s += "\n".join(
                [f"{var} = {sympy_to_str(expr)}" for var, expr in self.statements]
            )
            s += "\n"

        s += "Output:\n"
        s += sympy_to_str(self.output)
        return s

    def _latex(self, *args):
        """
        Return a LaTeX ``equation`` environment listing the intermediate
        statements and the output expression of this state updater.
        """
        from sympy import Symbol, latex

        x_old, x_current = Symbol("x"), Symbol("x_t")
        s = [r"\begin{equation}"]
        for var, expr in self.statements:
            expr = expr.subs(x_old, x_current)
            s.append(f"{latex(Symbol(var))} = {latex(expr)}\\\\")
        # Use the same symbol object as for the statements above (the
        # original passed the bare string "x_t" here, which sympy converts
        # to that very symbol)
        expr = self.output.subs(x_old, x_current)
        s.append(f"x_{{t+1}} = {latex(expr)}")
        s.append(r"\end{equation}")
        return "\n".join(s)

    def _repr_latex_(self):
        return self._latex()

    def replace_func(self, x, t, expr, temp_vars, eq_symbols, stochastic_variable=None):
        """
        Used to replace a single occurrence of ``f(x, t)`` or ``g(x, t)``:
        `expr` is the non-stochastic (in the case of ``f``) or stochastic
        part (``g``) of the expression defining the right-hand-side of a
        differential equation. Every state variable in `eq_symbols` is
        replaced by the expression given as `x` (with ``x`` standing for that
        state variable), and ``t`` is replaced by the expression given as
        `t`. Intermediate variables will be replaced with the appropriate
        variable-specific replacements as well.

        For example, in the `rk2` integrator, the second step involves the
        calculation of ``f(k/2 + x, dt/2 + t)``. If the state variable is
        ``v`` and `expr` is ``-v / tau``, this will result in
        ``-(_k_v/2 + v)/tau``.

        Parameters
        ----------
        x : sympy expression
            The first argument of ``f``/``g`` in the state updater description.
        t : sympy expression
            The second argument of ``f``/``g`` in the state updater
            description.
        expr : object
            The (part of the) right-hand side of the differential equation;
            it is converted via ``str(expr)`` and parsed with `str_to_sympy`.
        temp_vars : list of str
            Names of the intermediate variables (keys of ``self.symbols``).
        eq_symbols : dict
            Mapping from state variable names to their sympy symbols.
        stochastic_variable : str, optional
            If given, appended to the names of the variable-specific
            temporary variables.

        Returns
        -------
        s_expr : sympy expression
            The expression with all replacements applied.

        Raises
        ------
        ValueError
            If `expr` cannot be parsed.
        """
        try:
            s_expr = str_to_sympy(str(expr))
        except SympifyError as ex:
            raise ValueError(
                f'Error parsing the expression "{expr}": {str(ex)}'
            ) from ex

        for var in eq_symbols:
            # Generate specific temporary variables for the state variable,
            # e.g. '_k_v' for the state variable 'v' and the temporary
            # variable 'k'.
            if stochastic_variable is None:
                suffix = f"{var}"
            else:
                suffix = f"{var}_{stochastic_variable}"
            temp_var_replacements = {
                self.symbols[temp_var]: _symbol(f"{temp_var}_{suffix}")
                for temp_var in temp_vars
            }
            # In the expression given as 'x', replace 'x' by the variable
            # 'var' and all the temporary variables by their
            # variable-specific counterparts.
            x_replacement = x.subs(self.symbols["__x"], eq_symbols[var])
            x_replacement = x_replacement.subs(temp_var_replacements)

            # Replace the variable `var` in the expression by the new `x`
            # expression
            s_expr = s_expr.subs(eq_symbols[var], x_replacement)

        # If the expression given for t in the state updater description
        # is not just "t" (or rather "__t"), then replace t in the
        # equations by it, and replace "__t" by "t" afterwards.
        if t != self.symbols["__t"]:
            s_expr = s_expr.subs(SYMBOLS["t"], t)
            s_expr = s_expr.replace(self.symbols["__t"], SYMBOLS["t"])

        return s_expr

    def _non_stochastic_part(
        self,
        eq_symbols,
        non_stochastic,
        non_stochastic_expr,
        stochastic_variable,
        temp_vars,
        var,
    ):
        """
        Apply the non-stochastic part of a state updater statement
        (`non_stochastic_expr`) to the non-stochastic part of an equation
        (`non_stochastic`) for the state variable `var`.

        `stochastic_variable` selects how intermediate variables are named:

        * ``None`` or empty: ``{temp_var}_{var}``
        * a single name (str): ``{temp_var}_{var}_{stochastic_variable}``
        * a non-empty collection of names: the sum of
          ``{temp_var}_{var}_{xi}`` over all names ``xi``

        Returns a list with a single sympy expression.
        """
        if stochastic_variable is None or len(stochastic_variable) == 0:
            noise_for_f = None

            def specific(temp_var):
                return _symbol(f"{temp_var}_{var}")

        elif isinstance(stochastic_variable, str):
            # Only in this case is the noise variable passed on to
            # `replace_func`
            noise_for_f = stochastic_variable

            def specific(temp_var):
                return _symbol(f"{temp_var}_{var}_{stochastic_variable}")

        else:
            noise_for_f = None

            def specific(temp_var):
                return reduce(
                    operator.add,
                    [_symbol(f"{temp_var}_{var}_{xi}") for xi in stochastic_variable],
                )

        # Replace the f(x, t) part
        def replace_f(x, t):
            return self.replace_func(
                x, t, non_stochastic, temp_vars, eq_symbols, noise_for_f
            )

        result = non_stochastic_expr.replace(self.symbols["__f"], replace_f)
        # Replace x by the respective variable
        result = result.subs(self.symbols["__x"], eq_symbols[var])
        # Replace intermediate variables
        temp_var_replacements = {
            self.symbols[temp_var]: specific(temp_var) for temp_var in temp_vars
        }
        result = result.subs(temp_var_replacements)
        return [result]

    def _stochastic_term(
        self, eq_symbols, stochastic, stochastic_expr, xi, temp_vars, var
    ):
        """
        Apply the stochastic part of a state updater statement to the state
        variable `var` for the single stochastic variable `xi`.
        """

        # Replace the g(x, t) part
        def replace_g(x, t):
            return self.replace_func(
                x, t, stochastic.get(xi, 0), temp_vars, eq_symbols, xi
            )

        result = stochastic_expr.replace(self.symbols["__g"], replace_g)
        # Replace x by the respective variable
        result = result.subs(self.symbols["__x"], eq_symbols[var])
        # Replace dW by the respective variable
        result = result.subs(self.symbols["__dW"], xi)
        # Replace intermediate variables
        temp_var_replacements = {
            self.symbols[temp_var]: _symbol(f"{temp_var}_{var}_{xi}")
            for temp_var in temp_vars
        }
        return result.subs(temp_var_replacements)

    def _stochastic_part(
        self,
        eq_symbols,
        stochastic,
        stochastic_expr,
        stochastic_variable,
        temp_vars,
        var,
    ):
        """
        Apply the stochastic part of a state updater statement
        (`stochastic_expr`) to the stochastic parts of an equation
        (`stochastic`, a mapping from stochastic variable names to
        expressions) for the state variable `var`.

        `stochastic_variable` is either a single name (str) or an iterable of
        names. Returns a list with one sympy expression per name.
        """
        if isinstance(stochastic_variable, str):
            noise_variables = [stochastic_variable]
        else:
            noise_variables = stochastic_variable
        return [
            self._stochastic_term(
                eq_symbols, stochastic, stochastic_expr, xi, temp_vars, var
            )
            for xi in noise_variables
        ]

    def _generate_RHS(
        self,
        eqs,
        var,
        eq_symbols,
        temp_vars,
        expr,
        non_stochastic_expr,
        stochastic_expr,
        stochastic_variable=(),
    ):
        """
        Helper function used in `__call__`. Generates the right hand side of
        an abstract code statement by appropriately replacing f, g and t.
        For example, given a differential equation ``dv/dt = -(v + I) / tau``
        (i.e. `var` is ``v`` and `expr` is ``(-v + I) / tau``) together with
        the `rk2` step ``return x + dt*f(x +  k/2, t + dt/2)``
        (i.e. `non_stochastic_expr` is
        ``x + dt*f(x + k/2, t + dt/2)`` and `stochastic_expr` is ``None``),
        produces ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

        The result is returned as a string (via `sympy_to_str`).
        """
        # Note: in the following we are silently ignoring the case that a
        # state updater does not care about either the non-stochastic or the
        # stochastic part of an equation. We do trust state updaters to
        # correctly specify their own abilities (i.e. they do not claim to
        # support stochastic equations but actually just ignore the stochastic
        # part). We can't really check the issue here, as we are only dealing
        # with one line of the state updater description. It is perfectly valid
        # to write the euler update as:
        #     non_stochastic = dt * f(x, t)
        #     stochastic = dt**.5 * g(x, t) * xi
        #     return x + non_stochastic + stochastic
        #
        # In the above case, we'll deal with lines which do not define either
        # the stochastic or the non-stochastic part.
        non_stochastic, stochastic = expr.split_stochastic()

        if non_stochastic_expr is not None:
            # We do have a non-stochastic part in the state updater description
            non_stochastic_results = self._non_stochastic_part(
                eq_symbols,
                non_stochastic,
                non_stochastic_expr,
                stochastic_variable,
                temp_vars,
                var,
            )
        else:
            non_stochastic_results = []

        if stochastic is not None and stochastic_expr is not None:
            # We do have a stochastic part in the state
            # updater description
            stochastic_results = self._stochastic_part(
                eq_symbols,
                stochastic,
                stochastic_expr,
                stochastic_variable,
                temp_vars,
                var,
            )
        else:
            stochastic_results = []

        # All the parts (one non-stochastic and potentially more than one
        # stochastic part) are combined with addition
        RHS = sympy.Number(0)
        for partial_result in non_stochastic_results + stochastic_results:
            RHS += partial_result

        return sympy_to_str(RHS)

    def __call__(self, eqs, variables=None, method_options=None):
        """
        Apply a state updater description to model equations.

        Parameters
        ----------
        eqs : `Equations`
            The equations describing the model
        variables: dict-like, optional
            The `Variable` objects for the model. Ignored by the explicit
            state updater.
        method_options : dict, optional
            Additional options to the state updater (not used at the moment
            for the explicit state updaters).

        Returns
        -------
        code : str
            The abstract code, one statement per line.

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic and this state updater does not
            support (this kind of) stochastic equations.

        Examples
        --------
        >>> from brian2 import *
        >>> eqs = Equations('dv/dt = -v / tau : volt')
        >>> print(euler(eqs))
        _v = -dt*v/tau + v
        v = _v
        >>> print(rk4(eqs))
        __k_1_v = -dt*v/tau
        __k_2_v = -dt*(__k_1_v/2 + v)/tau
        __k_3_v = -dt*(__k_2_v/2 + v)/tau
        __k_4_v = -dt*(__k_3_v + v)/tau
        _v = __k_1_v/6 + __k_2_v/3 + __k_3_v/3 + __k_4_v/6 + v
        v = _v
        """
        extract_method_options(method_options, {})
        # Non-stochastic numerical integrators should work for all equations,
        # except for stochastic equations
        if eqs.is_stochastic and self.stochastic is None:
            raise UnsupportedEquationsException(
                "Cannot integrate stochastic equations with this state updater."
            )
        if self.custom_check:
            self.custom_check(eqs, variables)
        # The final list of statements
        statements = []

        stochastic_variables = eqs.stochastic_variables

        # The variables for the intermediate results in the state updater
        # description, e.g. the variable k in rk2
        intermediate_vars = [var for var, _ in self.statements]

        # A dictionary mapping all the variables in the equations to their
        # sympy representations
        eq_variables = {var: _symbol(var) for var in eqs.eq_names}

        # Generate the random numbers for the stochastic variables
        for stochastic_variable in stochastic_variables:
            statements.append(f"{stochastic_variable} = dt**.5 * randn()")

        substituted_expressions = eqs.get_substituted_expressions(variables)

        # Process the intermediate statements in the stateupdater description
        for intermediate_var, intermediate_expr in self.statements:
            # Split the expression into a non-stochastic and a stochastic part
            non_stochastic_expr, stochastic_expr = split_expression(intermediate_expr)

            # Execute the statement by appropriately replacing the functions f
            # and g and the variable x for every equation in the model.
            # We use the model equations where the subexpressions have
            # already been substituted into the model equations.
            for var, expr in substituted_expressions:
                for xi in stochastic_variables:
                    RHS = self._generate_RHS(
                        eqs,
                        var,
                        eq_variables,
                        intermediate_vars,
                        expr,
                        non_stochastic_expr,
                        stochastic_expr,
                        xi,
                    )
                    statements.append(f"{intermediate_var}_{var}_{xi} = {RHS}")
                if not stochastic_variables:  # no stochastic variables
                    RHS = self._generate_RHS(
                        eqs,
                        var,
                        eq_variables,
                        intermediate_vars,
                        expr,
                        non_stochastic_expr,
                        stochastic_expr,
                    )
                    statements.append(f"{intermediate_var}_{var} = {RHS}")

        # Process the "return" line of the stateupdater description
        non_stochastic_expr, stochastic_expr = split_expression(self.output)

        if eqs.is_stochastic and (
            self.stochastic != "multiplicative"
            and eqs.stochastic_type == "multiplicative"
        ):
            # The equations are marked as having multiplicative noise and the
            # current state updater does not support such equations. However,
            # it is possible that the equations do not use multiplicative noise
            # at all. They could depend on time via a function that is constant
            # over a single time step (most likely, a TimedArray). In that case
            # we can integrate the equations
            # NOTE(review): the membership test below raises a TypeError if
            # `variables` is left at its default of None -- confirm that
            # callers always provide the variables for stochastic equations.
            dt_value = variables["dt"].get_value()[0] if "dt" in variables else None
            for _, expr in substituted_expressions:
                _, stoch = expr.split_stochastic()
                if stoch is None:
                    continue
                # There could be more than one stochastic variable (e.g. xi_1, xi_2)
                for _, stoch_expr in stoch.items():
                    sympy_expr = str_to_sympy(stoch_expr.code)
                    # The equation really has multiplicative noise, if it depends
                    # on time (and not only via a function that is constant
                    # over dt), or if it depends on another variable defined
                    # via differential equations.
                    if not is_constant_over_dt(sympy_expr, variables, dt_value) or len(
                        stoch_expr.identifiers & eqs.diff_eq_names
                    ):
                        raise UnsupportedEquationsException(
                            "Cannot integrate "
                            "equations with "
                            "multiplicative noise with "
                            "this state updater."
                        )

        # Assign a value to all the model variables described by differential
        # equations
        for var, expr in substituted_expressions:
            RHS = self._generate_RHS(
                eqs,
                var,
                eq_variables,
                intermediate_vars,
                expr,
                non_stochastic_expr,
                stochastic_expr,
                stochastic_variables,
            )
            statements.append(f"_{var} = {RHS}")

        # Assign everything to the final variables
        for var, _ in substituted_expressions:
            statements.append(f"{var} = _{var}")

        return "\n".join(statements)
# ===============================================================================
# Explicit state updaters
# ===============================================================================
# these objects can be used like functions because they are callable
#: Forward Euler state updater (declared with support for additive noise)
euler = ExplicitStateUpdater(
    "x_new = x + dt * f(x,t) + g(x,t) * dW", stochastic="additive"
)
#: Second order Runge-Kutta method (midpoint method); no ``stochastic``
#: argument is given, i.e. no support for stochastic equations is declared
rk2 = ExplicitStateUpdater(
    """
k = dt * f(x,t)
x_new = x + dt*f(x + k/2, t + dt/2)"""
)
#: Classical Runge-Kutta method (RK4); no ``stochastic`` argument is given,
#: i.e. no support for stochastic equations is declared
rk4 = ExplicitStateUpdater(
    """
k_1 = dt*f(x,t)
k_2 = dt*f(x+k_1/2,t+dt/2)
k_3 = dt*f(x+k_2/2,t+dt/2)
k_4 = dt*f(x+k_3,t+dt)
x_new = x+(k_1+2*k_2+2*k_3+k_4)/6
"""
)
def diagonal_noise(equations, variables):
    """
    Checks whether we deal with diagonal noise, i.e. one independent noise
    variable per variable.

    Parameters
    ----------
    equations : `Equations`
        The equations to check.
    variables : dict-like
        Passed on unchanged to ``equations.get_substituted_expressions``.

    Returns
    -------
    None
        The function returns nothing; it only raises for unsupported
        equations. Non-stochastic equations are accepted without further
        checks.

    Raises
    ------
    UnsupportedEquationsException
        If the noise is not diagonal.
    """
    if not equations.is_stochastic:
        return

    # Both failure modes below report the same message
    error_message = (
        "Cannot integrate stochastic "
        "equations with non-diagonal "
        "noise with this state "
        "updater."
    )

    stochastic_vars = []
    for _, expr in equations.get_substituted_expressions(variables):
        expr_stochastic_vars = expr.stochastic_variables
        if len(expr_stochastic_vars) > 1:
            # More than one stochastic variable --> no diagonal noise
            raise UnsupportedEquationsException(error_message)
        stochastic_vars.extend(expr_stochastic_vars)

    # The noise is diagonal only if no stochastic variable is used in more
    # than one equation, i.e. if the collected names contain no duplicates
    if len(stochastic_vars) != len(set(stochastic_vars)):
        raise UnsupportedEquationsException(error_message)
#: Derivative-free Milstein method (declared with support for multiplicative
#: noise; `diagonal_noise` is installed as the custom check that runs
#: whenever the updater is applied to equations)
milstein = ExplicitStateUpdater(
    """
x_support = x + dt*f(x, t) + dt**.5 * g(x, t)
g_support = g(x_support, t)
k = 1/(2*dt**.5)*(g_support - g(x, t))*(dW**2)
x_new = x + dt*f(x,t) + g(x, t) * dW + k
""",
    stochastic="multiplicative",
    custom_check=diagonal_noise,
)
#: Stochastic Heun method (for multiplicative Stratonovich SDEs with
#: non-diagonal diffusion matrix)
heun = ExplicitStateUpdater(
    """
x_support = x + g(x,t) * dW
g_support = g(x_support,t+dt)
x_new = x + dt*f(x,t) + .5*dW*(g(x,t)+g_support)
""",
    stochastic="multiplicative",
)