AEIF models

This notebook provides a reference solution for the Adaptive Exponential Integrate and Fire (AEIF) neuronal model and compares it with several numerical implementations using simpler solvers. In particular, it justifies the change of implementation made in September 2016 to bring the simulation closer to the reference solution.
The equations governing the evolution of the AEIF model are
$$\left\lbrace\begin{array}{rcl} C_m\dot{V} &=& -g_L(V-E_L) + g_L \Delta_T e^{\frac{V-V_T}{\Delta_T}} + I_e + I_s(t) -w\\ \tau_w\dot{w} &=& a(V-E_L) - w \end{array}\right.$$
when $V < V_{peak}$ (threshold/spike detection). The simulations in this notebook use no synaptic input ($I_s = 0$). Once a spike occurs, we apply the reset conditions:
$$V \leftarrow V_r \quad \text{and} \quad w \leftarrow w + b$$
In the AEIF model, the spike is generated by the divergence of the exponential term. In practice, this means that just before the threshold is crossed, the argument of the exponential can become very large.
This can lead to numerical overflow or numerical instabilities in the solver, all the more if $V_{peak}$ is large, or if $\Delta_T$ is small.
The original solution was to bound the argument of the exponential to be smaller than 10 (an ad hoc value chosen to stay close to the original implementation in BRIAN). As will be shown in this notebook, this solution does not converge to the reference LSODAR solution.
The new implementation does not bound the argument of the exponential but the potential itself, since according to the theoretical model $V$ should never get larger than $V_{peak}$. We will show that this solution is not only closer to the reference solution in general, but also converges towards it as the timestep gets smaller.
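The reference solution is implemented using the LSODAR solver, which is described and compared in the following references: [reference list missing from this copy of the notebook].

The solutions using simpler solvers are computed by the `scipy_aeif` function (based on SciPy's `odeint`). The reference solution, computed by the `reference_aeif` function, is implemented through the assimulo package. To run this notebook, you need: numpy, scipy, matplotlib and assimulo.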
In [1]:
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
# IPython magic: render matplotlib figures inline in the notebook
%matplotlib inline
# wider default figure size for the comparison plots below
plt.rcParams['figure.figsize'] = (15, 6)
In [2]:
def rhs_aeif_new(y, _, p):
    '''
    Right-hand side of the AEIF equations, new implementation.

    The membrane potential used to evaluate both derivatives is clipped
    at `p.Vpeak` (V < V_peak), so the exponential term is never
    evaluated above the spike-detection threshold.

    Parameters
    ----------
    y : list
        State vector [V, w].
    _ : unused var
        Time argument required by the `odeint` calling convention.
    p : Params instance
        Object containing the neuronal parameters.

    Returns
    -------
    dv : double
        Derivative of V.
    dw : double
        Derivative of w.
    '''
    V_clipped, w = min(y[0], p.Vpeak), y[1]
    # exponential spike current; omitted when DeltaT == 0 (which would
    # otherwise divide by zero)
    Ispike = (p.gL * p.DeltaT * np.exp((V_clipped - p.vT) / p.DeltaT)
              if p.DeltaT != 0. else 0.)
    leak = -p.gL * (V_clipped - p.EL)
    dv = (leak + Ispike - w + p.Ie) / p.Cm
    dw = (p.a * (V_clipped - p.EL) - w) / p.tau_w
    return dv, dw
def rhs_aeif_old(y, _, p):
    '''
    Right-hand side of the AEIF equations, old implementation.

    V itself is not bounded; instead, the argument of the exponential
    is capped at 10 (e_arg < 10.), an ad hoc protection against
    overflow.

    Parameters
    ----------
    y : list
        State vector [V, w].
    _ : unused var
        Time argument required by the `odeint` calling convention.
    p : Params instance
        Object containing the neuronal parameters.

    Returns
    -------
    dv : double
        Derivative of V.
    dw : double
        Derivative of w.
    '''
    V = y[0]
    w = y[1]
    if p.DeltaT == 0.:
        Ispike = 0.
    else:
        exponent = min((V - p.vT) / p.DeltaT, 10.)
        Ispike = p.gL * p.DeltaT * np.exp(exponent)
    dv = (-p.gL * (V - p.EL) + Ispike - w + p.Ie) / p.Cm
    dw = (p.a * (V - p.EL) - w) / p.tau_w
    return dv, dw
In [3]:
def scipy_aeif(p, f, simtime, dt):
    '''
    Complete aeif model using scipy `odeint` solver.

    Imitates NEST's update scheme: the ODE is integrated with `odeint`
    separately over each interval [t_{k-1}, t_k] of a fixed grid, and the
    spike condition V >= Vpeak is only checked at the grid points. On a
    spike, V is reset to `p.Vreset` and `p.b` is added to w.

    Parameters
    ----------
    p : Params instance
        Object containing the neuronal parameters (`EL`, `Vpeak`,
        `Vreset` and `b` are read here, plus whatever `f` reads).
    f : function
        Right-hand side with `odeint`'s signature `f(y, t, p)` (either
        `rhs_aeif_old` or `rhs_aeif_new`).
    simtime : double
        Duration of the simulation. The time grid is
        `np.arange(0, simtime, dt)`, so `simtime` itself is excluded.
    dt : double
        Time increment.

    Returns
    -------
    t : numpy.ndarray
        Times at which the neuronal state was evaluated, shape (n,).
    y : numpy.ndarray
        State values [V, w] associated to the times in `t`, shape (n, 2).
        At a spike time the row holds the values after reset.
    s : list
        Spike times.
    vs : list
        Values of `V` at each spike, before reset.
    ws : list
        Values of `w` at each spike, before the `b` increment.
    fos : list
        Dictionaries of additional output information from `odeint`,
        one per integrated interval (n - 1 entries).

    Raises
    ------
    ValueError
        If the time grid is empty (e.g. `simtime <= 0`), since there
        would be no slot for the initial condition.
    '''
    t = np.arange(0, simtime, dt)  # time axis
    n = len(t)
    if n == 0:
        raise ValueError(
            "empty time grid: need simtime > 0 (got simtime={}, dt={})"
            .format(simtime, dt))
    y = np.zeros((n, 2))  # V, w
    y[0, 0] = p.EL  # Initial: (V_0, w_0) = (E_L, 5.)
    y[0, 1] = 5.
    s = []  # spike times
    vs = []  # membrane potential at spike before reset
    ws = []  # w at spike before step
    fos = []  # full output dict from odeint()
    # imitate NEST: update time-step by time-step
    for k in range(1, n):
        # solve ODE from t_{k-1} to t_k; d[1] is the state at t_k
        d, fo = odeint(f, y[k-1, :], t[k-1:k+1], (p, ), full_output=True)
        y[k, :] = d[1, :]
        fos.append(fo)
        # check for threshold crossing (only at grid points)
        if y[k, 0] >= p.Vpeak:
            s.append(t[k])
            vs.append(y[k, 0])
            ws.append(y[k, 1])
            y[k, 0] = p.Vreset  # reset
            y[k, 1] += p.b  # spike-triggered adaptation step
    return t, y, s, vs, ws, fos
In [4]:
from assimulo.solvers import LSODAR
from assimulo.problem import Explicit_Problem
class Extended_Problem(Explicit_Problem):
    '''
    assimulo explicit problem for the AEIF model without synapses
    ('AEIF_nosyn'), used to compute the LSODAR reference solution.

    The threshold V >= Vpeak is expressed as a state event. Spikes
    (time, V, w) are recorded in `ts_spikes`, `Vs_spikes` and `ws_spikes`
    by `state_events`, and the reset is applied in `handle_event`.
    '''
    # need variables here for access
    sw0 = [ False ]
    ts_spikes = []
    ws_spikes = []
    Vs_spikes = []

    def __init__(self, p):
        '''
        Parameters
        ----------
        p : Params instance
            Object containing the neuronal parameters.
        '''
        self.p = p
        # initial state (V_0, w_0) = (E_L, 5.), the same as in scipy_aeif
        self.y0 = [self.p.EL, 5.] # V, w
        # reset variables: rebinding per instance so that appends do not
        # go into the class-level lists shared by all instances
        self.ts_spikes = []
        self.ws_spikes = []
        self.Vs_spikes = []

    #The right-hand-side function (rhs)
    def rhs(self, t, y, sw):
        """
        This is the function we are trying to simulate (aeif model).

        Evaluates the unmodified AEIF equations: neither V nor the exponent
        is bounded here (contrast with rhs_aeif_old / rhs_aeif_new).
        `t` and `sw` are not used (presumably required by assimulo's
        rhs signature).

        Returns
        -------
        numpy.ndarray
            [dV/dt, dw/dt]
        """
        V, w = y[0], y[1]
        Ispike = 0.
        if self.p.DeltaT != 0.:
            Ispike = self.p.gL * self.p.DeltaT * np.exp((V-self.p.vT)/self.p.DeltaT)
        dotV = ( -self.p.gL*(V-self.p.EL) + Ispike + self.p.Ie - w ) / self.p.Cm
        dotW = ( self.p.a*(V-self.p.EL) - w ) / self.p.tau_w
        return np.array([dotV, dotW])

    # Sets a name to our function
    name = 'AEIF_nosyn'

    # The event function
    def state_events(self, t, y, sw):
        """
        This is our function that keeps track of our events. When the sign
        of any of the events has changed, we have an event.

        The single event function is piecewise constant: +5 while
        V < Vpeak and -5 once V >= Vpeak, so its sign flips at threshold.
        Side effect: every call with V >= Vpeak records (t, V, w) as a
        spike, unless t is close to the last recorded spike time.
        """
        event_0 = -5 if y[0] >= self.p.Vpeak else 5 # spike
        if event_0 < 0:
            if not self.ts_spikes:
                self.ts_spikes.append(t)
                self.Vs_spikes.append(y[0])
                self.ws_spikes.append(y[1])
            # presumably the solver evaluates this function several times
            # around one crossing; skip calls close to the last spike.
            # NOTE(review): the third positional argument of np.isclose is
            # rtol, so the window is ~1% of the last spike time (0.9 ms at
            # 90 ms) and grows over the simulation; an absolute window
            # (atol=...) may have been intended -- confirm.
            elif self.ts_spikes and not np.isclose(t, self.ts_spikes[-1], 0.01):
                self.ts_spikes.append(t)
                self.Vs_spikes.append(y[0])
                self.ws_spikes.append(y[1])
        return np.array([event_0])

    #Responsible for handling the events.
    def handle_event(self, solver, event_info):
        """
        Event handling. This functions is called when Assimulo finds an event as
        specified by the event functions.

        When the indicator for event 0 is positive, applies the AEIF reset
        (V <- Vreset, w <- w + b) to the solver state and sets sw[0].
        """
        ev = event_info  # NOTE(review): unused
        event_info = event_info[0] # only look at the state events information.
        # NOTE(review): event_0 goes from +5 to -5 on an upward V crossing;
        # confirm that assimulo reports a positive indicator for this
        # direction, otherwise the reset branch would not be taken.
        if event_info[0] > 0:
            solver.sw[0] = True
            solver.y[0] = self.p.Vreset
            solver.y[1] += self.p.b
        else:
            solver.sw[0] = False

    def initialize(self, solver):
        # history of step sizes (h) and of nq (presumably the method order),
        # appended by handle_result
        solver.h_sol=[]
        solver.nq_sol=[]

    def handle_result(self, solver, t, y):
        # store the result with the default behaviour first
        Explicit_Problem.handle_result(self, solver, t, y)
        # Extra output for algorithm analysis
        if solver.report_continuously:
            h, nq = solver.get_algorithm_data()
            solver.h_sol.extend([h])
            solver.nq_sol.extend([nq])
In [5]:
def reference_aeif(p, simtime):
    '''
    Reference aeif model using LSODAR (through assimulo).

    Parameters
    ----------
    p : Params instance
        Object containing the neuronal parameters.
    simtime : double
        Final time of the simulation (will run between 0 and `simtime`).

    Returns
    -------
    t : list
        Times at which the neuronal state was reported by the solver.
    y : array
        State values [V, w] associated to the times in `t`.
    s : list
        Spike times.
    vs : list
        Values of `V` at each detected spike (before reset).
    ws : list
        Values of `w` at each detected spike (before the `b` increment).
    h : list
        Step sizes returned by `solver.get_algorithm_data()` at each
        reported step.
    '''
    #Create an instance of the problem
    exp_mod = Extended_Problem(p) #Create the problem
    exp_sim = LSODAR(exp_mod) #Create the solver
    exp_sim.atol=1.e-8
    # required so that Extended_Problem.handle_result records the step sizes
    exp_sim.report_continuously = True
    exp_sim.store_event_points = True
    exp_sim.verbosity = 30
    #Simulate until simtime
    t, y = exp_sim.simulate(simtime)
    return t, y, exp_mod.ts_spikes, exp_mod.Vs_spikes, exp_mod.ws_spikes, exp_sim.h_sol
In [6]:
# Regular spiking
aeif_param = dict(
    V_reset=-58.,
    V_peak=0.0,
    V_th=-50.,
    I_e=420.,
    g_L=11.,
    tau_w=300.,
    E_L=-70.,
    Delta_T=2.,
    a=3.,
    b=0.,
    C_m=200.,
    V_m=-70.,  #! must be equal to E_L
    w=5.,  #! must be equal to 5.
    tau_syn_ex=0.2,
)

# Bursting
aeif_param2 = dict(
    V_reset=-46.,
    V_peak=0.0,
    V_th=-50.,
    I_e=500.0,
    g_L=10.,
    tau_w=120.,
    E_L=-58.,
    Delta_T=2.,
    a=2.,
    b=100.,
    C_m=200.,
    V_m=-58.,  #! must be equal to E_L
    w=5.,  #! must be equal to 5.
)

# Close to chaos (use resol < 0.005 and simtime = 200)
aeif_param3 = dict(
    V_reset=-48.,
    V_peak=0.0,
    V_th=-50.,
    I_e=160.,
    g_L=12.,
    tau_w=130.,
    E_L=-60.,
    Delta_T=2.,
    a=-11.,
    b=30.,
    C_m=100.,
    V_m=-60.,  #! must be equal to E_L
    w=5.,  #! must be equal to 5.
)
class Params(object):
    '''
    Class giving access to the neuronal parameters.

    Parameters
    ----------
    params : dict, optional
        Parameter dictionary using the keys of `aeif_param` ("V_peak",
        "V_reset", "g_L", "C_m", "E_L", "Delta_T", "tau_w", "a", "b",
        "V_th", "I_e"). Defaults to `aeif_param` (regular spiking), so
        `Params()` behaves as before; pass e.g. `aeif_param2` (bursting)
        or `aeif_param3` (close to chaos) to study the other regimes.

    Attributes
    ----------
    params : dict
        The dictionary the values were read from.
    Vpeak, Vreset, gL, Cm, EL, DeltaT, tau_w, a, b, vT, Ie
        Short aliases used by the right-hand-side functions. Note that
        `vT` is read from "V_th".
    '''

    def __init__(self, params=None):
        if params is None:
            # resolved at call time: the regular-spiking set by default
            params = aeif_param
        self.params = params
        self.Vpeak = params["V_peak"]
        self.Vreset = params["V_reset"]
        self.gL = params["g_L"]
        self.Cm = params["C_m"]
        self.EL = params["E_L"]
        self.DeltaT = params["Delta_T"]
        self.tau_w = params["tau_w"]
        self.a = params["a"]
        self.b = params["b"]
        self.vT = params["V_th"]
        self.Ie = params["I_e"]
# parameter object used by all simulations below (the aeif_param set)
p = Params()
In [7]:
# Parameters of the simulation: duration and fixed odeint time step (ms)
simtime, resol = 100., 0.01

# Fixed-step solutions: old (exponent capped) and new (V clipped) rhs
t_old, y_old, s_old, vs_old, ws_old, fo_old = scipy_aeif(
    p, rhs_aeif_old, simtime, resol)
t_new, y_new, s_new, vs_new, ws_new, fo_new = scipy_aeif(
    p, rhs_aeif_new, simtime, resol)

# Adaptive-step LSODAR reference solution
t_ref, y_ref, s_ref, vs_ref, ws_ref, h_ref = reference_aeif(
    p, simtime)
In [8]:
# Full-length comparison: V on the left axis (ax) and w on a twin right
# axis (ax2), for the LSODAR reference, old and new implementations.
fig, ax = plt.subplots()
ax2 = ax.twinx()
# Plot the potentials (colours come from the default cycle, in call order)
ax.plot(t_ref, y_ref[:,0], linestyle="-", label="V ref.")
ax.plot(t_old, y_old[:,0], linestyle="-.", label="V old")
ax.plot(t_new, y_new[:,0], linestyle="--", label="V new")
# Plot the adaptation variables
ax2.plot(t_ref, y_ref[:,1], linestyle="-", c="k", label="w ref.")
ax2.plot(t_old, y_old[:,1], linestyle="-.", c="m", label="w old")
ax2.plot(t_new, y_new[:,1], linestyle="--", c="y", label="w new")
# Show
ax.set_xlim([0., simtime])
ax.set_ylim([-65., 40.])
ax.set_xlabel("Time (ms)")
ax.set_ylabel("V (mV)")
ax2.set_ylim([-20., 20.])
ax2.set_ylabel("w (pA)")
ax.legend(loc=6)
ax2.legend(loc=2)
plt.show()
In [9]:
# Zoom on the 90-92 ms window. Uses the same colour/linestyle conventions
# as the full-length figure (w old in magenta, w new in yellow) so that
# both figures read the same way.
fig, ax = plt.subplots()
ax2 = ax.twinx()
# Plot the potentials
ax.plot(t_ref, y_ref[:,0], linestyle="-", label="V ref.")
ax.plot(t_old, y_old[:,0], linestyle="-.", label="V old")
ax.plot(t_new, y_new[:,0], linestyle="--", label="V new")
# Plot the adaptation variables
ax2.plot(t_ref, y_ref[:,1], linestyle="-", c="k", label="w ref.")
ax2.plot(t_old, y_old[:,1], linestyle="-.", c="m", label="w old")
ax2.plot(t_new, y_new[:,1], linestyle="--", c="y", label="w new")
ax.set_xlim([90., 92.])
ax.set_ylim([-65., 40.])
ax.set_xlabel("Time (ms)")
ax.set_ylabel("V (mV)")
ax2.set_ylim([17.5, 18.5])
ax2.set_ylabel("w (pA)")
ax.legend(loc=5)
ax2.legend(loc=2)
plt.show()
In [10]:
# Side-by-side printout of spike times and of V / w at each spike for the
# reference (LSODAR), old and new implementations, rounded to 3 decimals.
_comparisons = (
    ("spike times:\n-----------", (s_ref, s_old, s_new)),
    ("\nV at spike time:\n---------------", (vs_ref, vs_old, vs_new)),
    ("\nw at spike time:\n---------------", (ws_ref, ws_old, ws_new)),
)
for _title, _series in _comparisons:
    print(_title)
    for _tag, _values in zip(("ref", "old", "new"), _series):
        print(_tag, np.around(_values, 3))
# keep the notebook namespace free of the loop helpers
del _comparisons, _title, _series, _tag, _values
In [11]:
# Step sizes used by each solver (log scale). Each odeint call covers a
# single grid interval, so fo_old / fo_new hold one infodict per interval,
# aligned with t_old[1:] / t_new[1:]; 'hu' is odeint's record of the step
# size actually used.
plt.semilogy(t_ref, h_ref, label='Reference')
plt.semilogy(t_old[1:], [d['hu'] for d in fo_old], linewidth=2, label='Old')
plt.semilogy(t_new[1:], [d['hu'] for d in fo_new], label='New')
plt.legend(loc=6)
plt.show();
In [12]:
# Convergence check: compare old and new implementations with the
# reference for several time steps.
plt.plot(t_ref, y_ref[:,0], label="V ref.")
resolutions = (0.1, 0.01, 0.001)
# results per resolution, reused by the zoomed figure in the next cell
di_res = {}
# NOTE: the loop rebinds the module-level names resol, t_old, y_old, t_new
# and y_new; afterwards they hold the results for the last resolution.
for resol in resolutions:
    t_old, y_old, _, _, _, _ = scipy_aeif(p, rhs_aeif_old, simtime, resol)
    t_new, y_new, _, _, _, _ = scipy_aeif(p, rhs_aeif_new, simtime, resol)
    di_res[resol] = (t_old, y_old, t_new, y_new)
    plt.plot(t_old, y_old[:,0], linestyle=":", label="V old, r={}".format(resol))
    plt.plot(t_new, y_new[:,0], linestyle="--", linewidth=1.5, label="V new, r={}".format(resol))
plt.xlim(0., simtime)
plt.xlabel("Time (ms)")
plt.ylabel("V (mV)")
plt.legend(loc=2)
plt.show();
In [13]:
# Zoom on the 90-92 ms window of the resolution sweep, using the results
# stored in di_res as (t_old, y_old, t_new, y_new) per resolution.
plt.plot(t_ref, y_ref[:,0], label="V ref.")
for resol in resolutions:
    t_old, y_old = di_res[resol][:2]
    t_new, y_new = di_res[resol][2:]
    plt.plot(t_old, y_old[:,0], linestyle="--", label="V old, r={}".format(resol))
    plt.plot(t_new, y_new[:,0], linestyle="-.", linewidth=2., label="V new, r={}".format(resol))
plt.xlim(90., 92.)
plt.ylim([-62., 2.])
plt.xlabel("Time (ms)")
plt.ylabel("V (mV)")
plt.legend(loc=2)
plt.show();