Example: modelfitting_sbi

Note

You can launch an interactive, editable version of this example without installing any local files using the Binder service (although note that at some times this may be slow or fail to open): launchbinder

Model fitting with simulation-based inference

In this example, an HH-type (Hodgkin–Huxley-type) model is used to demonstrate simulation-based inference with the sbi toolbox (https://www.mackelab.org/sbi/). It is based on a synthetic current-clamp recording generated from the same model that we use in the inference process. Two of the model's parameters (the maximum sodium and potassium conductances) are treated as the free parameters to be inferred; all other parameters are kept fixed.

For more details about this approach, see the references below.

To run this example, you need to install the sbi package, e.g. with:

pip install sbi

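References:

- sbi documentation: https://www.mackelab.org/sbi/
- Tejero-Cantero et al. (2020). sbi: A toolkit for simulation-based inference. Journal of Open Source Software, 5(52), 2505. https://doi.org/10.21105/joss.02505
- Gonçalves et al. (2020). Training deep neural density estimators to identify mechanistic models of neural dynamics. eLife, 9, e56261.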

import matplotlib.pyplot as plt

from brian2 import *
import sbi.utils
import sbi.analysis
import sbi.inference
import torch  # PyTorch

# Integration time step shared by every simulation run in this example
defaultclock.dt = 0.05 * ms

def simulate(params, I=1*nA, t_on=50*ms, t_total=350*ms):
    """
    Simulate the HH-type model with Brian2 for every parameter set in `params`.

    One neuron is simulated per parameter set, all in a single run. A current
    step of amplitude `I` is injected between `t_on` and `t_total - t_on`.

    Parameters
    ----------
    params : array-like
        Conductance values in µS, one parameter set per row: column 0 is the
        maximum sodium conductance ``gNa``, column 1 the maximum potassium
        conductance ``gK``. A single 1-D parameter set is also accepted.
    I : Quantity, optional
        Amplitude of the injected current step.
    t_on : Quantity, optional
        Onset of the current step; the step ends ``t_on`` before the end of
        the simulation.
    t_total : Quantity, optional
        Total simulated time; must be larger than ``2*t_on``.

    Returns
    -------
    dict
        ``{'t': time steps, 'v': voltage (one row per neuron),
        'I_inj': injected current, 'spike_count': spike count per neuron}``.
    """
    assert t_total > 2*t_on
    t_off = t_total - t_on

    # Promote a single 1-D parameter set to one row so column indexing works
    params = np.atleast_2d(params)
    # fixed parameters
    gleak = 10*nS
    Eleak = -70*mV
    VT = -60.0*mV
    C = 200*pF
    ENa = 53*mV
    EK = -107*mV

    # The conductance-based model
    eqs = '''
         dVm/dt = -(gNa*m**3*h*(Vm - ENa) + gK*n**4*(Vm - EK) + gleak*(Vm - Eleak) - I_inj) / C : volt
         I_inj = int(t >= t_on and t < t_off)*I : amp (shared)
         dm/dt = alpham*(1-m) - betam*m : 1
         dn/dt = alphan*(1-n) - betan*n : 1
         dh/dt = alphah*(1-h) - betah*h : 1

         alpham = (-0.32/mV) * (Vm - VT - 13.*mV) / (exp((-(Vm - VT - 13.*mV))/(4.*mV)) - 1)/ms : Hz
         betam = (0.28/mV) * (Vm - VT - 40.*mV) / (exp((Vm - VT - 40.*mV)/(5.*mV)) - 1)/ms : Hz

         alphah = 0.128 * exp(-(Vm - VT - 17.*mV) / (18.*mV))/ms : Hz
         betah = 4/(1 + exp((-(Vm - VT - 40.*mV)) / (5.*mV)))/ms : Hz

         alphan = (-0.032/mV) * (Vm - VT - 15.*mV) / (exp((-(Vm - VT - 15.*mV)) / (5.*mV)) - 1)/ms : Hz
         betan = 0.5*exp(-(Vm - VT - 10.*mV) / (40.*mV))/ms : Hz
         # The parameters to fit
         gNa : siemens (constant)
         gK : siemens (constant)
         '''
    # A spike is registered when the Na activation gate m crosses 0.5; the
    # neuron stays refractory while m remains above 0.5, so each action
    # potential triggers the threshold only once.
    neurons = NeuronGroup(params.shape[0], eqs, threshold='m>0.5', refractory='m>0.5',
                          method='exponential_euler', name='neurons')
    Vm_mon = StateMonitor(neurons, 'Vm', record=True, name='Vm_mon')
    # Only spike counts are needed, so spike times are not recorded
    spike_mon = SpikeMonitor(neurons, record=False, name='spike_mon')
    # Free parameters (given in µS), one value per neuron; the names must
    # match the state variables declared in `eqs`.
    neurons.gNa = params[:, 0]*uS
    neurons.gK = params[:, 1]*uS

    neurons.Vm = 'Eleak'
    neurons.m = '1/(1 + betam/alpham)'         # Would be the solution when dm/dt = 0
    neurons.h = '1/(1 + betah/alphah)'         # Would be the solution when dh/dt = 0
    neurons.n = '1/(1 + betan/alphan)'         # Would be the solution when dn/dt = 0

    run(t_total)
    # For convenient plotting, reconstruct the current
    I_inj = ((Vm_mon.t >= t_on) & (Vm_mon.t < t_off))*I
    return dict(v=Vm_mon.Vm,
                t=Vm_mon.t,
                I_inj=I_inj,
                spike_count=spike_mon.count)


def calculate_summary_statistics(x):
    """
    Calculate summary statistics for the simulation results in `x`.

    `x` is a dictionary as returned by `simulate`, with the keys ``'v'``
    (voltage traces, one row per neuron), ``'I_inj'`` (injected current over
    time) and ``'spike_count'`` (spikes per neuron).

    Returns an array with one row per neuron and the four columns
    ``(spike count, mean voltage, voltage standard deviation, peak voltage)``.
    The voltage statistics are in mV and are computed over the stimulation
    period only.
    """
    I_inj = x["I_inj"]
    v = x["v"]/mV

    spike_count = x["spike_count"]
    # The stimulation period is where the injected current is non-zero.
    # Testing "!= 0" rather than "> 0" also covers hyperpolarizing
    # (negative) current steps, for which "> 0" would select nothing.
    stimulated = I_inj != 0*nA
    v_active = v[:, stimulated]
    # Mean and standard deviation during stimulation
    mean_active = np.mean(v_active, axis=1)
    std_active = np.std(v_active, axis=1)
    # Highest voltage reached during stimulation (action potential peak height)
    max_v = np.max(v_active, axis=1)

    # Stack statistics as rows, then transpose to one row per neuron
    sum_stats = np.vstack((spike_count, mean_active, std_active, max_v))

    return sum_stats.T


def simulation_wrapper(params):
    """
    Simulate the conductances in `params` and return their summary statistics.

    The output of `simulate` is reduced with `calculate_summary_statistics`.
    The result, one row per parameter set, is returned as a `torch.Tensor`
    of dtype ``torch.float32``.
    """
    results = simulate(params)
    stats_array = calculate_summary_statistics(results)
    return torch.as_tensor(stats_array).to(torch.float32)


if __name__ == '__main__':
    # Uniform (box) prior over the free parameters (gNa, gK), in µS
    prior_min = [.5, 1e-4]
    prior_max = [80., 15.]
    prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
                                            high=torch.as_tensor(prior_max))
    # Plot limits/ticks derived from the prior bounds so the two stay in sync
    prior_bounds = [[low, high] for low, high in zip(prior_min, prior_max)]

    # Simulate samples from the prior distribution
    theta = prior.sample((10_000,))
    # Flush so the message is shown before the long-running simulation starts
    print('Simulating samples from prior simulation... ', end='', flush=True)
    stats = simulation_wrapper(theta.numpy())
    print('done.')

    # Train inference network (SNPE with a mixture density network estimator)
    density_estimator_build_fun = sbi.utils.posterior_nn(model='mdn')
    inference = sbi.inference.SNPE(prior,
                                   density_estimator=density_estimator_build_fun)
    print('Training inference network... ')
    inference.append_simulations(theta, stats).train()
    posterior = inference.build_posterior()

    # Synthetic "observed" data generated from known ground-truth parameters
    true_params = np.array([[32., 1.]])
    true_data = simulate(true_params)
    t = true_data['t']
    v = true_data['v']
    xo = calculate_summary_statistics(true_data)
    print("The true summary statistics are:  ", xo)

    # Plot estimated posterior distribution, marking the true parameters
    samples = posterior.sample((1000,), x=xo, show_progress_bars=False)
    labels_params = [r'$\overline{g}_{Na}$', r'$\overline{g}_{K}$']
    sbi.analysis.pairplot(samples,
                          limits=prior_bounds,
                          ticks=prior_bounds,
                          figsize=(4, 4),
                          points=true_params, labels=labels_params,
                          points_offdiag={'markersize': 6},
                          points_colors=['r'])
    plt.tight_layout()

    # Draw a single sample from the posterior and convert to numpy for plotting.
    posterior_sample = posterior.sample((1,), x=xo,
                                        show_progress_bars=False).numpy()
    x = simulate(posterior_sample)

    # Plot the observed trace against the trace simulated from the sample
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(t/ms, v[0, :]/mV, lw=2, label='observation')
    ax.plot(t/ms, x['v'][0, :]/mV, '--', lw=2, label='posterior sample')
    ax.legend()
    ax.set(xlabel='time (ms)', ylabel='voltage (mV)')
    plt.show()
../_images/advanced.modelfitting_sbi.1.png ../_images/advanced.modelfitting_sbi.2.png