"""
Compartmental models.
This module defines the `SpatialNeuron` class, which defines multicompartmental
models.
"""
import copy
import weakref
import numpy as np
import sympy as sp
from brian2.core.variables import Variables
from brian2.equations.codestrings import Expression
from brian2.equations.equations import (
DIFFERENTIAL_EQUATION,
PARAMETER,
SUBEXPRESSION,
Equations,
SingleEquation,
extract_constant_subexpressions,
)
from brian2.groups.group import CodeRunner, Group
from brian2.groups.neurongroup import NeuronGroup, SubexpressionUpdater, to_start_stop
from brian2.groups.subgroup import Subgroup
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.units.allunits import amp, meter, ohm, siemens, volt
from brian2.units.fundamentalunits import (
DimensionMismatchError,
Quantity,
fail_for_dimension_mismatch,
have_same_dimensions,
)
from brian2.units.stdunits import cm, uF
from brian2.utils.logger import get_logger
__all__ = ["SpatialNeuron"]
logger = get_logger(__name__)
[docs]
class FlatMorphology:
    """
    Flattened, array-based representation of a morphology tree.

    The per-compartment and per-section properties of all sections are
    gathered into plain numpy arrays. Values are converted with
    ``np.asarray``, i.e. they are stored without unit information.
    """

    # Per-compartment geometry attributes, copied one-to-one from each section
    _GEOMETRY_ATTRIBUTES = (
        "length",
        "distance",
        "area",
        "diameter",
        "volume",
        "r_length_1",
        "r_length_2",
    )
    # Per-compartment coordinate attributes (only meaningful if the
    # morphology defines coordinates)
    _COORDINATE_ATTRIBUTES = (
        "start_x",
        "start_y",
        "start_z",
        "x",
        "y",
        "z",
        "end_x",
        "end_y",
        "end_z",
    )

    def __init__(self, morphology):
        """
        Flatten ``morphology`` (and, recursively, all of its children).

        Parameters
        ----------
        morphology : object
            Root of the morphology tree. It has to provide
            ``total_compartments`` and ``total_sections``, and each section
            has to provide ``n``, ``children``, ``end_distance``, the
            geometry attributes and (optionally) the coordinate attributes.
        """
        n_compartments = morphology.total_compartments
        self.n = n_compartments  # Total number of compartments

        # One float value per compartment for geometry and coordinates
        for attr in self._GEOMETRY_ATTRIBUTES + self._COORDINATE_ATTRIBUTES:
            setattr(self, attr, np.zeros(n_compartments))
        self.depth = np.zeros(n_compartments, dtype=np.int32)

        n_sections = morphology.total_sections
        self.sections = n_sections
        self.end_distance = np.zeros(n_sections)
        # Parent of each section; `_insert_data` stores ``parent_idx + 1``,
        # so that the value 0 refers to the (virtual) root
        self.morph_parent_i = np.zeros(n_sections, dtype=np.int32)
        # Children of each section, collected as a list of lists first and
        # converted into an array at the end of this method
        self.morph_children = []
        # Position of each section in the children list of its parent
        self.morph_idxchild = np.zeros(n_sections, dtype=np.int32)
        # First (inclusive) / last (exclusive) compartment of each section
        self.starts = np.zeros(n_sections, dtype=np.int32)
        self.ends = np.zeros(n_sections, dtype=np.int32)

        # Bookkeeping for the recursive traversal
        self._sections_without_coordinates = False
        self.has_coordinates = False
        self._offset = 0
        self._section_counter = 0
        self._insert_data(morphology)

        if self.has_coordinates and self._sections_without_coordinates:
            logger.info(
                "The morphology has a mix of sections with and "
                "without coordinates. The SpatialNeuron object "
                "will store NaN values for the coordinates of "
                "the sections that do not specify coordinates. "
                "Call generate_coordinates on the morphology "
                "before creating the SpatialNeuron object to fill "
                "in the missing coordinates."
            )

        # Morphologies without any coordinates do not keep coordinate arrays
        if not self.has_coordinates:
            for attr in self._COORDINATE_ATTRIBUTES:
                setattr(self, attr, None)

        # Store the list of lists as a (sections + 1) x (max children) table,
        # flattened to 1D and padded with zeros. This wastes space if the
        # number of children varies a lot between sections, but most
        # sections have 0, 1, or 2 children (e.g. SWC files on
        # neuromorpho.org are all binary trees).
        children_lists = self.morph_children
        self.morph_children_num = np.array(
            [len(children) for children in children_lists] + [0]
        )
        table_width = self.morph_children_num.max()
        children_table = np.zeros((n_sections + 1, table_width), dtype=np.int32)
        for row, children in enumerate(children_lists):
            children_table[row, : len(children)] = children
        self.morph_children = children_table.reshape(-1)

    def _insert_data(self, section, parent_idx=-1, depth=0):
        """
        Copy the data of ``section`` into the flat arrays, then recurse into
        its children.

        Parameters
        ----------
        section : object
            The section to insert.
        parent_idx : int, optional
            Index of the parent section (-1 for the root section).
        depth : int, optional
            Depth of ``section`` in the tree (0 for the root section).
        """
        first = self._offset
        last = first + section.n
        span = slice(first, last)

        # Compartment attributes
        self.depth[span] = depth
        for attr in self._GEOMETRY_ATTRIBUTES:
            getattr(self, attr)[span] = np.asarray(getattr(section, attr))

        if section.x is None:
            # No coordinates for this section: mark all of them as NaN
            self._sections_without_coordinates = True
            for attr in self._COORDINATE_ATTRIBUTES:
                getattr(self, attr)[span] = np.nan
        else:
            self.has_coordinates = True
            for attr in self._COORDINATE_ATTRIBUTES:
                getattr(self, attr)[span] = np.asarray(getattr(section, attr))

        # Section attributes
        section_idx = self._section_counter
        # Parent indices are shifted by one, since the index 0 is used for
        # the (virtual) root compartment
        parent_node = parent_idx + 1
        self.morph_parent_i[section_idx] = parent_node
        # Has to happen before looking up the parent's list: for the root
        # section, this call creates the list of the virtual root
        self.morph_children.append([])
        self.starts[section_idx] = first
        self.ends[section_idx] = last
        # Register this section in the children list of its parent
        siblings = self.morph_children[parent_node]
        self.morph_idxchild[section_idx] = len(siblings)
        siblings.append(section_idx + 1)
        self.end_distance[section_idx] = section.end_distance

        # Recurse down the tree
        self._offset = last
        self._section_counter += 1
        for child in section.children:
            self._insert_data(child, parent_idx=section_idx, depth=depth + 1)
[docs]
class SpatialNeuron(NeuronGroup):
    """
    A single neuron with a morphology and possibly many compartments.

    Parameters
    ----------
    morphology : `Morphology`
        The morphology of the neuron.
    model : str, `Equations`
        The equations defining the group.
    method : str, function, optional
        The numerical integration method. Either a string with the name of a
        registered method (e.g. "euler") or a function that receives an
        `Equations` object and returns the corresponding abstract code. If no
        method is specified, a suitable method will be chosen automatically.
    method_options : dict, optional
        Options for the numerical integration method; passed on unchanged to
        the `NeuronGroup` initializer.
    threshold : str, optional
        The condition which produces spikes. Should be a single line boolean
        expression.
    threshold_location : (int, `Morphology`), optional
        Compartment where the threshold condition applies, specified as an
        integer (compartment index) or a `Morphology` object corresponding to
        the compartment (e.g. ``morpho.axon[10*um]``).
        If unspecified, the threshold condition applies at all compartments.
    Cm : `Quantity`, optional
        Specific capacitance in uF/cm**2 (default 0.9). It can be accessed and
        modified later as a state variable. In particular, its value can differ
        in different compartments.
    Ri : `Quantity`, optional
        Intracellular resistivity in ohm.cm (default 150). It can be accessed
        as a shared state variable, but modified only before the first run.
        It is uniform across the neuron.
    reset : str, optional
        The (possibly multi-line) string with the code to execute on reset.
    events : dict, optional
        User-defined events in addition to the "spike" event defined by the
        ``threshold``. Has to be a mapping of strings (the event name) to
        strings (the condition) that will be checked.
    refractory : {str, `Quantity`}, optional
        Either the length of the refractory period (e.g. ``2*ms``), a string
        expression that evaluates to the length of the refractory period
        after each spike (e.g. ``'(1 + rand())*ms'``), or a string expression
        evaluating to a boolean value, given the condition under which the
        neuron stays refractory after a spike (e.g. ``'v > -20*mV'``)
    namespace : dict, optional
        A dictionary mapping identifier names to objects. If not given, the
        namespace will be filled in at the time of the call of `Network.run`,
        with either the values from the ``namespace`` argument of the
        `Network.run` method or from the local context, if no such argument is
        given.
    dtype : (`dtype`, `dict`), optional
        The `numpy.dtype` that will be used to store the values, or a
        dictionary specifying the type for variable names. If a value is not
        provided for a variable (or no value is provided at all), the preference
        setting `core.default_float_dtype` is used.
    dt : `Quantity`, optional
        The time step to be used for the simulation. Cannot be combined with
        the `clock` argument.
    clock : `Clock`, optional
        The update clock to be used. If neither a clock, nor the `dt` argument
        is specified, the `defaultclock` will be used.
    order : int, optional
        The priority of this group for operations occurring at the same time
        step and in the same scheduling slot. Defaults to 0.
    name : str, optional
        A unique name for the group, otherwise use ``spatialneuron_0``, etc.

    Raises
    ------
    TypeError
        If ``model`` is neither a string nor an `Equations` object, if no
        ``morphology`` is given, if the transmembrane current ``Im`` is
        missing or has flags, or if ``Im`` cannot be differentiated with
        respect to ``v``.
    AttributeError
        If ``threshold_location`` is a morphology that refers to more than
        a single compartment.
    """

    def __init__(
        self,
        morphology=None,
        model=None,
        threshold=None,
        refractory=False,
        reset=None,
        events=None,
        threshold_location=None,
        dt=None,
        clock=None,
        order=0,
        Cm=0.9 * uF / cm**2,
        Ri=150 * ohm * cm,
        name="spatialneuron*",
        dtype=None,
        namespace=None,
        method=("exact", "exponential_euler", "rk2", "heun"),
        method_options=None,
    ):
        # #### Prepare and validate equations
        if isinstance(model, str):
            model = Equations(model)
        if not isinstance(model, Equations):
            raise TypeError(
                "model has to be a string or an Equations "
                f"object, is '{type(model)}' instead."
            )
        # Fail early with a clear message instead of an AttributeError on
        # None further down
        if morphology is None:
            raise TypeError("A morphology has to be provided.")

        # Insert the threshold mechanism at the specified location
        if threshold_location is not None:
            if hasattr(threshold_location, "_indices"):  # assuming this is a method
                threshold_location = threshold_location._indices()
                # for now, only a single compartment allowed; use the
                # converted integer so that the condition below contains a
                # plain index and not the string representation of an array
                try:
                    threshold_location = int(threshold_location)
                except TypeError as ex:
                    raise AttributeError(
                        "Threshold can only be applied on a single location"
                    ) from ex
            threshold = f"({threshold}) and (i == {threshold_location!s})"

        # Check flags (we have point currents)
        model.check_flags(
            {
                DIFFERENTIAL_EQUATION: ("point current",),
                PARAMETER: ("constant", "shared", "linked", "point current"),
                SUBEXPRESSION: ("shared", "point current", "constant over dt"),
            }
        )
        #: The original equations as specified by the user (i.e. before
        #: inserting point-currents into the membrane equation, before adding
        #: all the internally used variables and constants, etc.).
        self.user_equations = model

        # Separate subexpressions depending whether they are considered to be
        # constant over a time step or not (this would also be done by the
        # NeuronGroup initializer later, but this would give incorrect results
        # for the linearity check)
        model, constant_over_dt = extract_constant_subexpressions(model)

        # Extract membrane equation
        if "Im" not in model:
            raise TypeError("The transmembrane current 'Im' must be defined")
        if model["Im"].flags:
            raise TypeError(
                "Cannot specify any flags for the transmembrane current 'Im'."
            )
        membrane_expr = model["Im"].expr  # the membrane equation

        model_equations = []
        # Insert point currents in the membrane equation
        for eq in model.values():
            if eq.varname == "Im":
                continue  # ignore -- handled separately
            if "point current" in eq.flags:
                fail_for_dimension_mismatch(
                    eq.dim, amp, f"Point current {eq.varname} should be in amp"
                )
                membrane_expr = Expression(
                    f"{str(membrane_expr.code)}+{eq.varname}/area"
                )
                # Drop the flag (keeping the order of the remaining flags)
                eq = SingleEquation(
                    eq.type,
                    eq.varname,
                    eq.dim,
                    expr=eq.expr,
                    flags=[flag for flag in eq.flags if flag != "point current"],
                )
            model_equations.append(eq)

        model_equations.append(
            SingleEquation(
                SUBEXPRESSION,
                "Im",
                dimensions=(amp / meter**2).dim,
                expr=membrane_expr,
            )
        )
        model_equations.append(SingleEquation(PARAMETER, "v", volt.dim))
        model = Equations(model_equations)

        ###### Process model equations (Im) to extract total conductance and
        ###### the remaining current
        # Expand expressions in the membrane equation
        for var, expr in model.get_substituted_expressions(include_subexpressions=True):
            if var == "Im":
                Im_expr = expr
                break
        else:
            raise AssertionError("Model equations did not contain Im!")

        # Differentiate Im with respect to v
        Im_sympy_exp = str_to_sympy(Im_expr.code)
        v_sympy = sp.Symbol("v", real=True)
        diffed = sp.diff(Im_sympy_exp, v_sympy)

        unevaled_derivatives = diffed.atoms(sp.Derivative)
        if unevaled_derivatives:
            raise TypeError(
                f"Cannot take the derivative of '{Im_expr.code}' with respect to v."
            )

        # gtot = -dIm/dv and I0 = Im - dIm/dv * v
        gtot_str = sympy_to_str(sp.simplify(-diffed))
        I0_str = sympy_to_str(sp.simplify(Im_sympy_exp - diffed * v_sympy))

        # A plain "0" would not have the units declared for the subexpression
        if gtot_str == "0":
            gtot_str += "*siemens/meter**2"
        if I0_str == "0":
            I0_str += "*amp/meter**2"
        gtot_str = f"gtot__private={gtot_str}: siemens/meter**2"
        I0_str = f"I0__private={I0_str}: amp/meter**2"
        model += Equations(f"{gtot_str}\n{I0_str}")

        # Insert morphology (store a copy)
        self.morphology = copy.deepcopy(morphology)

        # Flatten the morphology
        self.flat_morphology = FlatMorphology(morphology)

        # Equations for morphology
        # TODO: check whether Cm and Ri are already in the equations
        #       no: should be shared instead of constant
        #       yes: should be constant (check)
        eqs_constants = Equations(
            "\n"
            "length : meter (constant)\n"
            "distance : meter (constant)\n"
            "area : meter**2 (constant)\n"
            "volume : meter**3\n"
            "Ic : amp/meter**2\n"
            "diameter : meter (constant)\n"
            "Cm : farad/meter**2 (constant)\n"
            "Ri : ohm*meter (constant, shared)\n"
            "r_length_1 : meter (constant)\n"
            "r_length_2 : meter (constant)\n"
            "time_constant = Cm/gtot__private : second\n"
            "space_constant = (2/pi)**(1.0/3.0) * "
            "(area/(1/r_length_1 + 1/r_length_2))**(1.0/6.0) /\n"
            "(2*(Ri*gtot__private)**(1.0/2.0)) : meter\n"
        )
        if self.flat_morphology.has_coordinates:
            eqs_constants += Equations(
                "\n"
                "x : meter (constant)\n"
                "y : meter (constant)\n"
                "z : meter (constant)\n"
            )

        NeuronGroup.__init__(
            self,
            morphology.total_compartments,
            model=model + eqs_constants,
            method_options=method_options,
            threshold=threshold,
            refractory=refractory,
            reset=reset,
            events=events,
            method=method,
            dt=dt,
            clock=clock,
            order=order,
            namespace=namespace,
            dtype=dtype,
            name=name,
        )
        # Parameters and intermediate variables for solving the cable equations
        # Note that some of these variables could have meaningful physical
        # units (e.g. _v_star is in volt, _I0_all is in amp/meter**2 etc.) but
        # since these variables should never be used in user code, we don't
        # assign them any units
        self.variables.add_arrays(
            [
                "_ab_star0",
                "_ab_star1",
                "_ab_star2",
                "_b_plus",
                "_b_minus",
                "_v_star",
                "_u_plus",
                "_u_minus",
                "_v_previous",
                "_c",
                # The following two are only necessary for
                # C code where we cannot deal with scalars
                # and arrays interchangeably:
                "_I0_all",
                "_gtot_all",
            ],
            size=self.N,
            read_only=True,
        )

        self.Cm = Cm
        self.Ri = Ri
        # These explicit assignments will load the morphology values from disk
        # in standalone mode
        self.distance_ = self.flat_morphology.distance
        self.length_ = self.flat_morphology.length
        self.area_ = self.flat_morphology.area
        self.diameter_ = self.flat_morphology.diameter
        self.volume_ = self.flat_morphology.volume
        self.r_length_1_ = self.flat_morphology.r_length_1
        self.r_length_2_ = self.flat_morphology.r_length_2
        if self.flat_morphology.has_coordinates:
            self.x_ = self.flat_morphology.x
            self.y_ = self.flat_morphology.y
            self.z_ = self.flat_morphology.z

        # Performs numerical integration step
        self.add_attribute("diffusion_state_updater")
        self.diffusion_state_updater = SpatialStateUpdater(
            self, method, clock=self.clock, order=order
        )

        # Update v after the gating variables to obtain consistent Ic and Im
        self.diffusion_state_updater.order = 1

        # Creation of contained_objects that do the work
        self.contained_objects.extend([self.diffusion_state_updater])

        if constant_over_dt:
            self.subexpression_updater = SubexpressionUpdater(self, constant_over_dt)
            self.contained_objects.append(self.subexpression_updater)

    def __getattr__(self, name):
        """
        Subtrees are accessed by attribute, e.g. neuron.axon.
        """
        return self.spatialneuron_attribute(self, name)

    def __getitem__(self, item):
        """
        Selects a segment, where ``item`` is a slice of either compartment
        indexes or distances.
        Note: a segment is not a `SpatialNeuron`, only a `Group`.
        """
        return self.spatialneuron_segment(self, item)

    @staticmethod
    def _find_subtree_end(morpho):
        """
        Go down a morphology recursively to find the (absolute) index of the
        "final" compartment (i.e. the one with the highest index) of the
        subtree.

        Parameters
        ----------
        morpho : `Morphology`
            The morphology for which to find the index.

        Returns
        -------
        index : int
            The highest index within the subtree.
        """
        indices = [morpho.indices[-1]]
        indices.extend(
            SpatialNeuron._find_subtree_end(child) for child in morpho.children
        )
        return max(indices)

    @staticmethod
    def spatialneuron_attribute(neuron, name):
        """
        Selects a subtree from `SpatialNeuron` neuron and returns a `SpatialSubgroup`.
        If it does not exist, returns the `Group` attribute.
        """
        if name == "main":  # Main section, without the subtrees
            indices = neuron.morphology.indices[:]
            start, stop = indices[0], indices[-1]
            return SpatialSubgroup(
                neuron, start, stop + 1, morphology=neuron.morphology
            )
        elif (name != "morphology") and (
            (name in getattr(neuron.morphology, "children", []))
            or all(c in "LR123456789" for c in name)
        ):  # subtree
            morpho = neuron.morphology[name]
            start = morpho.indices[0]
            stop = SpatialNeuron._find_subtree_end(morpho)
            return SpatialSubgroup(neuron, start, stop + 1, morphology=morpho)
        else:
            # The explicit exclusion of "morphology" above makes sure that a
            # lookup of the attribute itself ends up here
            return Group.__getattr__(neuron, name)

    @staticmethod
    def spatialneuron_segment(neuron, item):
        """
        Selects a segment from `SpatialNeuron` neuron, where item is a slice of
        either compartment indexes or distances.
        Note: a segment is not a `SpatialNeuron`, only a `Group`.

        Raises
        ------
        ValueError
            If a step size is given for a slice based on length.
        DimensionMismatchError
            If the start or stop of a slice based on length is not given in
            units of meter.
        IndexError
            If the resulting start index is not smaller than the stop index.
        """
        if isinstance(item, slice) and isinstance(item.start, Quantity):
            if item.step is not None:
                raise ValueError(
                    "Cannot specify a step size for slicing based on length."
                )
            start, stop = item.start, item.stop
            if not have_same_dimensions(start, meter) or not have_same_dimensions(
                stop, meter
            ):
                raise DimensionMismatchError(
                    "Start and stop should have units of meter", start, stop
                )
            # Convert to integers (compartment numbers)
            indices = neuron.morphology.indices[item]
            start, stop = indices[0], indices[-1] + 1
        elif not isinstance(item, slice) and hasattr(item, "indices"):
            start, stop = to_start_stop(item.indices[:], neuron._N)
        else:
            start, stop = to_start_stop(item, neuron._N)
            # Indices given relative to a subgroup are shifted by its start
            if isinstance(neuron, SpatialSubgroup):
                start += neuron.start
                stop += neuron.start

        if start >= stop:
            raise IndexError(
                f"Illegal start/end values for subgroup, {int(start)}>={int(stop)}"
            )
        if isinstance(neuron, SpatialSubgroup):
            # Note that the start/stop values calculated above are always
            # absolute values, even for subgroups
            neuron = neuron.source
        return Subgroup(neuron, start, stop)
[docs]
class SpatialSubgroup(Subgroup):
    """
    A subgroup of a `SpatialNeuron`.

    Parameters
    ----------
    source : `SpatialNeuron` or `SpatialSubgroup`
        The group that the compartments are taken from. For a
        `SpatialSubgroup`, its own ``source`` is used instead.
    start : int
        First compartment.
    stop : int
        Ending compartment, not included (as in slices).
    morphology : `Morphology`
        Morphology corresponding to the subgroup (not the full
        morphology).
    name : str, optional
        Name of the subgroup.
    """

    def __init__(self, source, start, stop, morphology, name=None):
        # Stored before the `Subgroup` initializer runs
        self.morphology = morphology
        if isinstance(source, SpatialSubgroup):
            # Always refer to the group underlying the subgroup. Note that
            # the offset is read from that underlying group, i.e. after
            # ``source`` has been replaced.
            source = source.source
            offset = source.start
            start, stop = start + offset, stop + offset
        Subgroup.__init__(self, source, start, stop, name)

    def __getattr__(self, name):
        """
        Access subtrees by attribute, see
        `SpatialNeuron.spatialneuron_attribute`.
        """
        return SpatialNeuron.spatialneuron_attribute(self, name)

    def __getitem__(self, item):
        """
        Select a segment, see `SpatialNeuron.spatialneuron_segment`.
        """
        return SpatialNeuron.spatialneuron_segment(self, item)
[docs]
class SpatialStateUpdater(CodeRunner, Group):
    """
    The `CodeRunner` that updates the state variables of a `SpatialNeuron`
    at every timestep.
    """

    def __init__(self, group, method, clock, order=0):
        """
        Parameters
        ----------
        group : `SpatialNeuron`
            The neuron (a group of compartments) to update; has to provide
            a ``flat_morphology`` attribute.
        method
            The integration method choice, stored as ``method_choice``.
        clock : `Clock`
            The clock passed on to the `CodeRunner` initializer.
        order : int, optional
            The order passed on to the `CodeRunner` initializer.
        """
        # group is the neuron (a group of compartments)
        self.method_choice = method
        self.group = weakref.proxy(group)

        flat = group.flat_morphology
        n_compartments = flat.n
        n_sections = flat.sections
        # Length of the flattened children table
        n_children_entries = len(flat.morph_children)

        CodeRunner.__init__(
            self,
            group,
            "spatialstateupdate",
            code="_gtot = gtot__private\n_I0 = I0__private",
            clock=clock,
            when="groups",
            order=order,
            name=f"{group.name}_spatialstateupdater*",
            check_units=False,
            template_kwds={"number_sections": n_sections},
        )

        self.variables = Variables(self, default_index="_section_idx")
        self.variables.add_reference("N", group)

        # One value per compartment
        self.variables.add_arange("_compartment_idx", size=n_compartments)
        self.variables.add_array(
            "_invr",
            dimensions=siemens.dim,
            size=n_compartments,
            constant=True,
            index="_compartment_idx",
        )

        # one value per section
        self.variables.add_arange("_section_idx", size=n_sections)
        # elements below diagonal
        self.variables.add_array("_P_parent", size=n_sections, constant=True)
        self.variables.add_arrays(
            ["_morph_idxchild", "_morph_parent_i", "_starts", "_ends"],
            size=n_sections,
            dtype=np.int32,
            constant=True,
        )
        self.variables.add_arrays(
            ["_invr0", "_invrn"],
            dimensions=siemens.dim,
            size=n_sections,
            constant=True,
        )

        # one value per section + 1 value for the root
        n_nodes = n_sections + 1
        self.variables.add_arange("_section_root_idx", size=n_nodes)
        for varname in ("_P_diag", "_B"):
            self.variables.add_array(
                varname, size=n_nodes, constant=True, index="_section_root_idx"
            )
        self.variables.add_array(
            "_morph_children_num",
            size=n_nodes,
            dtype=np.int32,
            constant=True,
            index="_section_root_idx",
        )

        # 2D matrices of size (sections + 1) x max children per section
        self.variables.add_arange("_morph_children_idx", size=n_children_entries)
        # elements above diagonal
        self.variables.add_array(
            "_P_children",
            size=n_children_entries,
            index="_morph_children_idx",
            constant=True,
        )
        self.variables.add_array(
            "_morph_children",
            size=n_children_entries,
            dtype=np.int32,
            constant=True,
            index="_morph_children_idx",
        )
        self._enable_group_attributes()

        # Copy the tree structure from the flattened morphology; each value
        # is assigned to the attribute of the same name with a leading
        # underscore
        for attr in (
            "morph_parent_i",
            "morph_children_num",
            "morph_children",
            "morph_idxchild",
            "starts",
            "ends",
        ):
            setattr(self, f"_{attr}", getattr(flat, attr))