NEST implementation of the aeif models

Hans Ekkehard Plesser and Tanguy Fardet, 2016-09-09

This notebook provides a reference solution for the Adaptive Exponential Integrate and Fire (AEIF) neuronal model and compares it with several numerical implementations using simpler solvers. In particular, it justifies the change of implementation made in September 2016 to bring the simulation closer to the reference solution.

Problem statement

Basics

The equations governing the evolution of the AEIF model are

$$\left\lbrace\begin{array}{rcl} C_m\dot{V} &=& -g_L(V-E_L) + g_L \Delta_T e^{\frac{V-V_T}{\Delta_T}} + I_e + I_s(t) -w\\ \tau_w\dot{w} &=& a(V-E_L) - w \end{array}\right.$$

as long as $V < V_{peak}$; a spike is detected when $V$ reaches $V_{peak}$. Once a spike occurs, we apply the reset conditions:

$$V = V_r \quad \text{and} \quad w = w + b$$

Divergence

In the AEIF model, the spike is generated by the exponential divergence. In practice, this means that just before threshold crossing, the argument of the exponential can become very large.

This can lead to numerical overflow or numerical instabilities in the solver, especially if $V_{peak}$ is large or if $\Delta_T$ is small.
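To give a feeling for the magnitudes involved, here is a minimal sketch that evaluates the exponential argument $(V-V_T)/\Delta_T$ and locates the potential at which a double-precision exponential overflows. The values of `V_T` and `Delta_T` are the ones used in the parameter dictionaries of this notebook; the helper name `exp_arg` is only illustrative.

```python
import numpy as np

V_T, Delta_T = -50., 2.   # mV, values used in this notebook's parameter sets

def exp_arg(V):
    """Argument of the exponential in the spike current."""
    return (V - V_T) / Delta_T

# At V = V_peak = 0 mV the argument is already 25
arg_at_peak = exp_arg(0.)
print("argument at V_peak:", arg_at_peak)

# Largest argument that np.exp can represent in double precision (about 709.78)
largest_arg = np.log(np.finfo(np.float64).max)
V_overflow = V_T + Delta_T * largest_arg
print("exp overflows for V above about {:.0f} mV".format(V_overflow))

# Beyond that point the exponential evaluates to inf
with np.errstate(over="ignore"):
    overflowed = np.exp(largest_arg + 1.)
print("exp(largest_arg + 1) =", overflowed)
```

A solver that overshoots $V_{peak}$ during a trial step can reach such values quickly, because the derivative itself grows exponentially with $V$.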

Tested solutions

Old implementation (before September 2016)

The original solution was to bound the argument of the exponential to be smaller than 10 (an ad hoc value chosen to stay close to the original implementation in BRIAN). As will be shown in the notebook, this solution does not converge to the reference LSODAR solution.

New implementation

The new implementation does not bound the argument of the exponential, but the potential itself, since according to the theoretical model, $V$ should never get larger than $V_{peak}$. We will show that this solution is not only closer to the reference solution in general, but also converges towards it as the timestep gets smaller.
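The difference between the two bounds can be illustrated on the spike current alone. The sketch below uses the regular-spiking values from this notebook; the function names `ispike_old` and `ispike_new` are illustrative and are not part of NEST.

```python
import numpy as np

# Regular-spiking values from this notebook's parameter dictionary
g_L, Delta_T, V_T, V_peak = 11., 2., -50., 0.

def ispike_old(V):
    """Spike current with the exponential argument capped at 10."""
    e_arg = np.minimum((V - V_T) / Delta_T, 10.)
    return g_L * Delta_T * np.exp(e_arg)

def ispike_new(V):
    """Spike current with the potential capped at V_peak."""
    V = np.minimum(V, V_peak)
    return g_L * Delta_T * np.exp((V - V_T) / Delta_T)

V = np.array([-50., -30., -10., 0.])
ratio = ispike_old(V) / ispike_new(V)
print("V (mV):         ", V)
print("old / new ratio:", ratio)
```

With these values the old bound becomes active at $V = V_T + 10\,\Delta_T = -30$ mV, so between $-30$ mV and $V_{peak}$ the old spike current is held constant while the new one keeps growing. This slows the upswing of the spike in the old implementation.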

Reference solution

The reference solution is implemented using the LSODAR solver which is described and compared in the following references:

Technical details and requirements

Implementation of the functions

  • The old and new implementations are reproduced using SciPy and are called by the scipy_aeif function.
  • The NEST implementations are not shown here, but keep in mind that for a given time resolution, they are closer to the reference result than the SciPy implementation since the GSL implementation uses an RK45 adaptive solver.
  • The reference solution using LSODAR, called reference_aeif, is implemented through the assimulo package.

Requirements

To run this notebook, you need:

  • numpy and scipy
  • matplotlib
  • assimulo (for the LSODAR reference solution)


In [1]:
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (15, 6)

SciPy functions mimicking the NEST code

Right hand side functions


In [2]:
def rhs_aeif_new(y, _, p):
    '''
    New implementation bounding V < V_peak
    
    Parameters
    ----------
    y : list
        Vector containing the state variables [V, w]
    _ : unused var
    p : Params instance
        Object containing the neuronal parameters.
        
    Returns
    -------
    dv : double
        Derivative of V
    dw : double
        Derivative of w
    '''
    v = min(y[0], p.Vpeak)
    w = y[1]
    Ispike = 0.
    
    if p.DeltaT != 0.:
        Ispike = p.gL * p.DeltaT * np.exp((v-p.vT)/p.DeltaT)
        
    dv = (-p.gL*(v-p.EL) + Ispike - w + p.Ie)/p.Cm
    dw = (p.a * (v-p.EL) - w) / p.tau_w
    
    return dv, dw


def rhs_aeif_old(y, _, p):
    '''
    Old implementation bounding the argument of the
    exponential function (e_arg < 10.).
    
    Parameters
    ----------
    y : list
        Vector containing the state variables [V, w]
    _ : unused var
    p : Params instance
        Object containing the neuronal parameters.
        
    Returns
    -------
    dv : double
        Derivative of V
    dw : double
        Derivative of w
    '''
    v = y[0]
    w = y[1]
    Ispike = 0.
    
    if p.DeltaT != 0.:
        e_arg = min((v-p.vT)/p.DeltaT, 10.)
        Ispike = p.gL * p.DeltaT * np.exp(e_arg)
        
    dv = (-p.gL*(v-p.EL) + Ispike - w + p.Ie)/p.Cm
    dw = (p.a * (v-p.EL) - w) / p.tau_w
    
    return dv, dw

Complete model


In [3]:
def scipy_aeif(p, f, simtime, dt):
    '''
    Complete aeif model using scipy `odeint` solver.
    
    Parameters
    ----------
    p : Params instance
        Object containing the neuronal parameters.
    f : function
        Right-hand side function (either `rhs_aeif_old`
        or `rhs_aeif_new`)
    simtime : double
        Duration of the simulation (will run between
        0 and `simtime`)
    dt : double
        Time increment.
        
    Returns
    -------
    t : list
        Times at which the neuronal state was evaluated.
    y : list
        State values associated to the times in `t`
    s : list
        Spike times.
    vs : list
        Values of `V` just before the spike.
    ws : list
        Values of `w` just before the spike
    fos : list
        List of dictionaries containing additional output
        information from `odeint`
    '''
    t = np.arange(0, simtime, dt)   # time axis
    n = len(t)                      # number of time points
    y = np.zeros((n, 2))            # state variables V, w
    y[0, 0] = p.EL                  # initial condition: V_0 = E_L
    y[0, 1] = 5.                    # initial condition: w_0 = 5.
    s = []      # spike times
    vs = []     # membrane potential at spike before reset
    ws = []     # w at spike before step
    fos = []    # full output dict from odeint()
    
    # imitate NEST: update time-step by time-step
    for k in range(1, n):
        
        # solve ODE from t_k-1 to t_k
        d, fo = odeint(f, y[k-1, :], t[k-1:k+1], (p, ), full_output=True)
        y[k, :] = d[1, :]
        fos.append(fo)
        
        # check for threshold crossing
        if y[k, 0] >= p.Vpeak:
            s.append(t[k])
            vs.append(y[k, 0])
            ws.append(y[k, 1])
            
            y[k, 0] = p.Vreset  # reset
            y[k, 1] += p.b      # step
            
    return t, y, s, vs, ws, fos

LSODAR reference solution

Setting assimulo class


In [4]:
from assimulo.solvers import LSODAR
from assimulo.problem import Explicit_Problem

class Extended_Problem(Explicit_Problem):

    # need variables here for access
    sw0 = [ False ]
    ts_spikes = []
    ws_spikes = []
    Vs_spikes = []
    
    def __init__(self, p):
        self.p = p
        self.y0 = [self.p.EL, 5.]   # V, w
        # reset variables
        self.ts_spikes = []
        self.ws_spikes = []
        self.Vs_spikes = []

    #The right-hand-side function (rhs)

    def rhs(self, t, y, sw):
        """
        This is the function we are trying to simulate (aeif model).
        """
        V, w = y[0], y[1]
        Ispike = 0.
        
        if self.p.DeltaT != 0.:
            Ispike = self.p.gL * self.p.DeltaT * np.exp((V-self.p.vT)/self.p.DeltaT)
        dotV = ( -self.p.gL*(V-self.p.EL) + Ispike + self.p.Ie - w ) / self.p.Cm
        dotW = ( self.p.a*(V-self.p.EL) - w ) / self.p.tau_w
        return np.array([dotV, dotW])

    # Sets a name to our function
    name = 'AEIF_nosyn'

    # The event function
    def state_events(self, t, y, sw):
        """
        This is our function that keeps track of our events. When the sign
        of any of the events has changed, we have an event.
        """
        event_0 = -5 if y[0] >= self.p.Vpeak else 5 # spike
        if event_0 < 0:
            # record the spike unless it duplicates the last recorded one
            if not self.ts_spikes or not np.isclose(t, self.ts_spikes[-1], 0.01):
                self.ts_spikes.append(t)
                self.Vs_spikes.append(y[0])
                self.ws_spikes.append(y[1])
        return np.array([event_0])

    #Responsible for handling the events.
    def handle_event(self, solver, event_info):
        """
        Event handling. This function is called when Assimulo finds an event as
        specified by the event functions.
        """
        event_info = event_info[0] # only look at the state events information.
        if event_info[0] > 0:
            solver.sw[0] = True
            solver.y[0] = self.p.Vreset
            solver.y[1] += self.p.b
        else:
            solver.sw[0] = False

    def initialize(self, solver):
        solver.h_sol=[]
        solver.nq_sol=[]

    def handle_result(self, solver, t, y):
        Explicit_Problem.handle_result(self, solver, t, y)
        # Extra output for algorithm analysis
        if solver.report_continuously:
           h, nq = solver.get_algorithm_data()
           solver.h_sol.extend([h])
           solver.nq_sol.extend([nq])

LSODAR reference model


In [5]:
def reference_aeif(p, simtime):
    '''
    Reference aeif model using LSODAR.
    
    Parameters
    ----------
    p : Params instance
        Object containing the neuronal parameters.
    simtime : double
        Duration of the simulation (will run between
        0 and `simtime`)
        
    Returns
    -------
    t : list
        Times at which the neuronal state was evaluated.
    y : list
        State values associated to the times in `t`
    s : list
        Spike times.
    vs : list
        Values of `V` just before the spike.
    ws : list
        Values of `w` just before the spike
    h : list
        List of the minimal time increment at each step.
    '''
    #Create an instance of the problem
    exp_mod = Extended_Problem(p) #Create the problem
    exp_sim = LSODAR(exp_mod) #Create the solver

    exp_sim.atol=1.e-8
    exp_sim.report_continuously = True
    exp_sim.store_event_points = True

    exp_sim.verbosity = 30

    #Simulate
    t, y = exp_sim.simulate(simtime) #Simulate from 0 to simtime
    
    return t, y, exp_mod.ts_spikes, exp_mod.Vs_spikes, exp_mod.ws_spikes, exp_sim.h_sol
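For readers who do not have assimulo installed, the same event-driven idea can be sketched with SciPy's `solve_ivp` and a terminal event. This is not the reference implementation used in this notebook: it uses SciPy's default RK45 method instead of LSODAR, it bounds $V$ by $V_{peak}$ in the right-hand side (as in the new implementation) so that trial steps cannot overflow, and the names `rhs`, `spike` and `run` are illustrative. The parameter values are the regular-spiking ones from this notebook.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Regular-spiking parameters copied from this notebook
gL, Cm, EL, VT, DeltaT = 11., 200., -70., -50., 2.
a, b, tau_w, Ie = 3., 0., 300., 420.
Vpeak, Vreset = 0., -58.

def rhs(t, y):
    """AEIF right-hand side; V is bounded by Vpeak (assumption of this sketch)."""
    V = min(y[0], Vpeak)
    w = y[1]
    Ispike = gL * DeltaT * np.exp((V - VT) / DeltaT)
    dV = (-gL * (V - EL) + Ispike + Ie - w) / Cm
    dw = (a * (V - EL) - w) / tau_w
    return [dV, dw]

def spike(t, y):
    """Event function: zero when V reaches Vpeak."""
    return y[0] - Vpeak

spike.terminal = True   # stop the integration at the event
spike.direction = 1     # only upward crossings count

def run(simtime):
    """Integrate piecewise between spikes and return the spike times."""
    t0, y0, spikes = 0., [EL, 5.], []
    while t0 < simtime:
        sol = solve_ivp(rhs, (t0, simtime), y0, events=spike,
                        rtol=1e-6, atol=1e-8)
        if sol.status != 1:          # end of interval reached without spike
            break
        t0 = sol.t_events[0][0]      # spike time located by root finding
        spikes.append(t0)
        y0 = [Vreset, sol.y[1, -1] + b]   # reset V, step w
    return spikes

spikes = run(100.)
print(np.around(spikes, 3))
```

Restarting the solver after each reset avoids integrating across the discontinuity, which is the same reason the assimulo problem class handles the reset in `handle_event`.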

Set the parameters and simulate the models

Params (choose a dictionary)


In [6]:
# Regular spiking
aeif_param = {
    'V_reset': -58.,
    'V_peak': 0.0,
    'V_th': -50.,
    'I_e': 420.,
    'g_L': 11.,
    'tau_w': 300.,
    'E_L': -70.,
    'Delta_T': 2.,
    'a': 3.,
    'b': 0.,
    'C_m': 200.,
    'V_m': -70., #! must be equal to E_L
    'w': 5., #! must be equal to 5.
    'tau_syn_ex': 0.2
}

# Bursting
aeif_param2 = {
    'V_reset': -46.,
    'V_peak': 0.0,
    'V_th': -50.,
    'I_e': 500.0,
    'g_L': 10.,
    'tau_w': 120.,
    'E_L': -58.,
    'Delta_T': 2.,
    'a': 2.,
    'b': 100.,
    'C_m': 200.,
    'V_m': -58., #! must be equal to E_L
    'w': 5., #! must be equal to 5.
}

# Close to chaos (use resol < 0.005 and simtime = 200)
aeif_param3 = {
    'V_reset': -48.,
    'V_peak': 0.0,
    'V_th': -50.,
    'I_e': 160.,
    'g_L': 12.,
    'tau_w': 130.,
    'E_L': -60.,
    'Delta_T': 2.,
    'a': -11.,
    'b': 30.,
    'C_m': 100.,
    'V_m': -60., #! must be equal to E_L
    'w': 5., #! must be equal to 5.
}

class Params(object):
    '''
    Class giving access to the neuronal
    parameters.
    '''
    def __init__(self):
        self.params = aeif_param
        self.Vpeak = aeif_param["V_peak"]
        self.Vreset = aeif_param["V_reset"]
        self.gL = aeif_param["g_L"]
        self.Cm = aeif_param["C_m"]
        self.EL = aeif_param["E_L"]
        self.DeltaT = aeif_param["Delta_T"]
        self.tau_w = aeif_param["tau_w"]
        self.a = aeif_param["a"]
        self.b = aeif_param["b"]
        self.vT = aeif_param["V_th"]
        self.Ie = aeif_param["I_e"]
    
p = Params()
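The class above always reads from `aeif_param`, so switching to another parameter set means editing the class. As a possible alternative, here is a sketch of a variant that takes the dictionary as an argument. The class name `ParamsFromDict` is hypothetical, and the dictionary below is a copy of the bursting set from this notebook so that the example runs on its own.

```python
class ParamsFromDict(object):
    '''
    Hypothetical variant of `Params` that receives the
    parameter dictionary as an argument.
    '''
    # attribute name used by the right-hand sides -> dictionary key
    _keys = {
        "Vpeak": "V_peak", "Vreset": "V_reset", "gL": "g_L",
        "Cm": "C_m", "EL": "E_L", "DeltaT": "Delta_T",
        "tau_w": "tau_w", "a": "a", "b": "b",
        "vT": "V_th", "Ie": "I_e",
    }

    def __init__(self, param_dict):
        self.params = param_dict
        for attr, key in self._keys.items():
            setattr(self, attr, param_dict[key])

# Copy of the bursting parameter set defined in this notebook
bursting = {
    'V_reset': -46., 'V_peak': 0.0, 'V_th': -50., 'I_e': 500.0,
    'g_L': 10., 'tau_w': 120., 'E_L': -58., 'Delta_T': 2.,
    'a': 2., 'b': 100., 'C_m': 200., 'V_m': -58., 'w': 5.,
}

p_burst = ParamsFromDict(bursting)
print(p_burst.Vreset, p_burst.b, p_burst.tau_w)
```

An instance built this way exposes the same attributes as `Params`, so it could be passed to `scipy_aeif` and `reference_aeif` unchanged.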

Simulate the 3 implementations


In [7]:
# Parameters of the simulation
simtime = 100.
resol = 0.01

t_old, y_old, s_old, vs_old, ws_old, fo_old = scipy_aeif(p, rhs_aeif_old, simtime, resol)
t_new, y_new, s_new, vs_new, ws_new, fo_new = scipy_aeif(p, rhs_aeif_new, simtime, resol)
t_ref, y_ref, s_ref, vs_ref, ws_ref, h_ref = reference_aeif(p, simtime)


Final Run Statistics: AEIF_nosyn 

 Number of steps                       : 2013
 Number of function evaluations        : 5590
 Number of Jacobian evaluations        : 0
 Number of state function evaluations  : 2042
 Number of state events                : 7

Solver options:

 Solver                  : LSODAR 
 Absolute tolerances     : [  1.00000000e-08   1.00000000e-08]
 Relative tolerances     : 1e-06
 Starter                 : classical

Simulation interval    : 0.0 - 100.0 seconds.
Elapsed simulation time: 0.09047099999999997 seconds.

Plot the results

Zoom out


In [8]:
fig, ax = plt.subplots()
ax2 = ax.twinx()

# Plot the potentials
ax.plot(t_ref, y_ref[:,0], linestyle="-", label="V ref.")
ax.plot(t_old, y_old[:,0], linestyle="-.", label="V old")
ax.plot(t_new, y_new[:,0], linestyle="--", label="V new")

# Plot the adaptation variables
ax2.plot(t_ref, y_ref[:,1], linestyle="-", c="k", label="w ref.")
ax2.plot(t_old, y_old[:,1], linestyle="-.", c="m", label="w old")
ax2.plot(t_new, y_new[:,1], linestyle="--", c="y", label="w new")

# Show
ax.set_xlim([0., simtime])
ax.set_ylim([-65., 40.])
ax.set_xlabel("Time (ms)")
ax.set_ylabel("V (mV)")
ax2.set_ylim([-20., 20.])
ax2.set_ylabel("w (pA)")
ax.legend(loc=6)
ax2.legend(loc=2)
plt.show()


Zoom in


In [9]:
fig, ax = plt.subplots()
ax2 = ax.twinx()

# Plot the potentials
ax.plot(t_ref, y_ref[:,0], linestyle="-", label="V ref.")
ax.plot(t_old, y_old[:,0], linestyle="-.", label="V old")
ax.plot(t_new, y_new[:,0], linestyle="--", label="V new")

# Plot the adaptation variables
ax2.plot(t_ref, y_ref[:,1], linestyle="-", c="k", label="w ref.")
ax2.plot(t_old, y_old[:,1], linestyle="-.", c="m", label="w old")
ax2.plot(t_new, y_new[:,1], linestyle="--", c="y", label="w new")

ax.set_xlim([90., 92.])
ax.set_ylim([-65., 40.])
ax.set_xlabel("Time (ms)")
ax.set_ylabel("V (mV)")
ax2.set_ylim([17.5, 18.5])
ax2.set_ylabel("w (pA)")
ax.legend(loc=5)
ax2.legend(loc=2)
plt.show()


Compare properties at spike times


In [10]:
print("spike times:\n-----------")
print("ref", np.around(s_ref, 3)) # ref lsodar
print("old", np.around(s_old, 3))
print("new", np.around(s_new, 3))

print("\nV at spike time:\n---------------")
print("ref", np.around(vs_ref, 3)) # ref lsodar
print("old", np.around(vs_old, 3))
print("new", np.around(vs_new, 3))

print("\nw at spike time:\n---------------")
print("ref", np.around(ws_ref, 3)) # ref lsodar
print("old", np.around(ws_old, 3))
print("new", np.around(ws_new, 3))


spike times:
-----------
ref [ 18.715  30.561  42.495  54.517  66.626  78.819  91.096]
old [ 18.73  30.59  42.54  54.58  66.71  78.92  91.22]
new [ 18.72  30.57  42.51  54.54  66.66  78.86  91.14]

V at spike time:
---------------
ref [ 0.006  0.03   0.025  0.036  0.033  0.031  0.041]
old [  6.128   5.615   6.107  10.186  17.895   4.997  20.766]
new [ 32413643.009  32591616.327  35974587.741  51016349.639  77907589.627
  37451353.637  11279320.151]

w at spike time:
---------------
ref [  7.359   9.328  11.235  13.08   14.864  16.589  18.256]
old [  7.367   9.344  11.258  13.111  14.906  16.637  18.315]
new [  7.362   9.334  11.244  13.093  14.883  16.611  18.278]
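The spike times printed above can be turned into a simple error measure. The sketch below copies the printed values into arrays and computes the absolute deviation from the reference for each spike; note that the old and new spike times are quantized to the 0.01 ms grid, so part of the deviation is due to the resolution itself.

```python
import numpy as np

# Spike times (ms) copied from the output above
s_ref = np.array([18.715, 30.561, 42.495, 54.517, 66.626, 78.819, 91.096])
s_old = np.array([18.73, 30.59, 42.54, 54.58, 66.71, 78.92, 91.22])
s_new = np.array([18.72, 30.57, 42.51, 54.54, 66.66, 78.86, 91.14])

err_old = np.abs(s_old - s_ref)
err_new = np.abs(s_new - s_ref)

print("old: |dt| per spike =", np.around(err_old, 3))
print("new: |dt| per spike =", np.around(err_new, 3))
print("old: max |dt| = {:.3f} ms".format(err_old.max()))
print("new: max |dt| = {:.3f} ms".format(err_new.max()))
```

For every spike the new implementation is closer to the reference, and in both cases the deviation accumulates from one spike to the next.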

Size of minimal integration timestep


In [11]:
plt.semilogy(t_ref, h_ref, label='Reference')
plt.semilogy(t_old[1:], [d['hu'] for d in fo_old], linewidth=2, label='Old')
plt.semilogy(t_new[1:], [d['hu'] for d in fo_new], label='New')

plt.legend(loc=6)
plt.show();


Convergence towards LSODAR reference with step size

Zoom out


In [12]:
plt.plot(t_ref, y_ref[:,0], label="V ref.")
resolutions = (0.1, 0.01, 0.001)
di_res = {}

for resol in resolutions:
    t_old, y_old, _, _, _, _ = scipy_aeif(p, rhs_aeif_old, simtime, resol)
    t_new, y_new, _, _, _, _ = scipy_aeif(p, rhs_aeif_new, simtime, resol)
    di_res[resol] = (t_old, y_old, t_new, y_new)
    plt.plot(t_old, y_old[:,0], linestyle=":", label="V old, r={}".format(resol))
    plt.plot(t_new, y_new[:,0], linestyle="--", linewidth=1.5, label="V new, r={}".format(resol))
plt.xlim(0., simtime)
plt.xlabel("Time (ms)")
plt.ylabel("V (mV)")
plt.legend(loc=2)
plt.show();


Zoom in


In [13]:
plt.plot(t_ref, y_ref[:,0], label="V ref.")
for resol in resolutions:
    t_old, y_old = di_res[resol][:2]
    t_new, y_new = di_res[resol][2:]
    plt.plot(t_old, y_old[:,0], linestyle="--", label="V old, r={}".format(resol))
    plt.plot(t_new, y_new[:,0], linestyle="-.", linewidth=2., label="V new, r={}".format(resol))
plt.xlim(90., 92.)
plt.ylim([-62., 2.])
plt.xlabel("Time (ms)")
plt.ylabel("V (mV)")
plt.legend(loc=2)
plt.show();
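The visual comparison could be complemented by a number per resolution. A minimal sketch, assuming the reference time axis is increasing: interpolate the reference trace onto the fixed grid and take the root-mean-square difference. The helper `rms_error` is illustrative, and the data below are synthetic stand-ins so that the example runs on its own; in the notebook one would pass `t_ref`, `y_ref[:,0]` and the entries of `di_res` instead.

```python
import numpy as np

def rms_error(t_ref, v_ref, t, v):
    """
    RMS difference between a fixed-grid trace (t, v) and the
    reference trace linearly interpolated onto the same grid.
    """
    v_ref_on_grid = np.interp(t, t_ref, v_ref)
    return np.sqrt(np.mean((v - v_ref_on_grid) ** 2))

# Synthetic stand-in data: a smooth "reference" and two coarser traces
# whose error is made proportional to the step size on purpose
t_ref = np.linspace(0., 10., 2001)
v_ref = np.sin(t_ref)

errors = {}
for dt in (0.1, 0.01):
    t = np.arange(0., 10., dt)
    v = np.sin(t) + dt
    errors[dt] = rms_error(t_ref, v_ref, t, v)
    print("dt = {}: RMS error = {:.4f}".format(dt, errors[dt]))
```

For spiking traces, small offsets in spike time dominate such a measure because $V$ changes by tens of mV within a fraction of a millisecond, so comparing spike times directly may be more informative.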