In [1]:
%load_ext Cython
In [2]:
import numpy as np
import pulse2percept as p2p
%matplotlib inline
import matplotlib.pyplot as plt
In [3]:
class LegacyNanduri2012(p2p.retina.Nanduri2012):
    """Reference copy of the old, convolution-based Nanduri2012 cascade.

    Kept behaviourally unchanged so that the output of the Cython model can
    be compared against it.
    """

    def __init__(self, **kwargs):
        """Install the default parameters and build the three gamma kernels.

        Parameters
        ----------
        **kwargs
            Overrides for the defaults set here. They are handed to
            ``set_kwargs`` with its first argument set to True, which
            requests a warning for unrecognized keywords.

        Notes
        -----
        ``self.tsample`` is read below but never assigned in this method.
        NOTE(review): presumably it arrives through ``kwargs`` (the notebook
        constructs this class with ``tsample=...``) -- confirm.
        """
        # Default model parameters; the time constants are written as
        # <value> / 1000.
        defaults = {
            'tau1': 0.42 / 1000,
            'tau2': 45.25 / 1000,
            'tau3': 26.25 / 1000,
            'eps': 8.73,
            'asymptote': 14.0,
            'slope': 3.0,
            'shift': 16.0,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

        # Let the caller override any of the defaults above.
        self.set_kwargs(True, **kwargs)

        # (attribute suffix, first argument to gamma) for each kernel. Only
        # the second item returned by gamma() is stored.
        for suffix, order in (('1', 1), ('2', 1), ('3', 3)):
            tau = getattr(self, 'tau' + suffix)
            _, kernel = p2p.utils.gamma(order, tau, self.tsample)
            setattr(self, 'gamma' + suffix, kernel)

    def model_cascade(self, in_arr, pt_list, layers, use_jit):
        """Run the Nanduri model cascade using convolutions.

        Parameters
        ----------
        in_arr : array-like
            A 2D array specifying the effective current values at a
            particular spatial location (pixel); one value per retinal
            layer and electrode.
            Dimensions: <#layers x #electrodes>
        pt_list : list
            List of pulse train 'data' containers.
            Dimensions: <#electrodes x #time points>
        layers : list
            List of retinal layers to simulate. Choose from:
            - 'OFL': optic fiber layer
            - 'GCL': ganglion cell layer
        use_jit : bool
            If True, applies just-in-time (JIT) compilation to expensive
            computations for additional speed-up (requires Numba). Only
            forwarded to the first (sparse) convolution.

        Returns
        -------
        p2p.utils.TimeSeries
            Built from ``self.tsample`` and the output of the final (slow)
            stage.

        Raises
        ------
        ValueError
            If 'INL' is listed in `layers`.
        """
        if 'INL' in layers:
            raise ValueError("The Nanduri2012 model does not support an inner "
                             "nuclear layer.")

        dt = self.tsample

        # The paper specifies cathodic-first pulses, but the results are only
        # reproduced with what is now called anodic-first, hence the sign
        # flip on the stimulus.
        stim = -self.calc_layer_current(in_arr, pt_list)
        n_samples = stim.size

        # Stage 1: fast response (convolution with gamma1).
        fast = dt * p2p.utils.conv(stim, self.gamma1, mode='full',
                                   method='sparse',
                                   use_jit=use_jit)[:n_samples]

        # Stage 2: charge accumulation (running sum of the positive part of
        # the stimulus, smoothed with gamma2), subtracted from the fast
        # response and rectified.
        charge = dt * np.cumsum(np.maximum(0, stim))
        charge = dt * p2p.utils.conv(charge, self.gamma2, mode='full',
                                     method='fft')[:n_samples]
        rectified = np.maximum(0, fast - self.eps * charge)

        # Stage 3: stationary nonlinearity, scaled by the peak of the
        # rectified signal. `ss` is scipy.special, which the notebook imports
        # in a later cell; that cell has to run before this method is called.
        # NOTE(review): when `rectified` is all zeros, `peak` is 0 and the
        # division below is 0 / 0 -- presumably yielding NaNs; left as-is
        # because this class is a preserved reference.
        peak = rectified.max()
        gain = ss.expit((peak - self.shift) / self.slope)
        scaled = rectified / peak * gain * self.asymptote

        # Stage 4: slow response (convolution with gamma3).
        slow = dt * p2p.utils.conv(scaled, self.gamma3, mode='full',
                                   method='fft')[:n_samples]

        return p2p.utils.TimeSeries(dt, slow)
In [4]:
import scipy.special as ss
def finite_diff(stim, model, maxR3=99.8873446571, skip_i=0):
    """Run the Nanduri cascade as an explicit finite-difference loop.

    Every stage is a leaky integrator advanced by one step of size
    ``stim.tsample`` per sample: ``x += dt * (input - x) / tau``.

    Parameters
    ----------
    stim : object
        Must provide ``tsample`` (step size), ``duration`` and ``data``
        (indexable samples). ``data`` needs at least as many samples as
        ``np.arange(0, stim.duration, stim.tsample)`` has entries.
    model : object
        Must provide ``tau1``, ``tau2``, ``tau3``, ``eps``, ``asymptote``,
        ``shift`` and ``slope``.
    maxR3 : float, optional
        Value the rectified signal R3 is divided by, and the argument of the
        sigmoid that sets the output gain.
        NOTE(review): the default is presumably the peak R3 for the stimulus
        used in this notebook -- confirm before reusing with other stimuli.
    skip_i : int, optional
        Unused; kept so that existing calls remain valid.

    Returns
    -------
    out_t : np.ndarray
        ``np.arange(0, stim.duration, stim.tsample)``.
    out_R4 : np.ndarray
        Output of the last of the three cascaded integrators at every step;
        same shape as `out_t`.
    """
    dt = stim.tsample
    out_t = np.arange(0, stim.duration, stim.tsample)
    out_R4 = np.zeros_like(out_t)

    # The nonlinearity's scale factor depends only on the arguments, so it is
    # computed once instead of on every time step.
    sc_fac = model.asymptote * ss.expit((maxR3 - model.shift) / model.slope)

    tmp_R1 = 0          # fast response
    tmp_chargeacc = 0   # running integral of the positive part of the stimulus
    tmp_ca = 0          # leaky-integrated charge accumulation
    # Input slot [0] followed by the three stages of the slow response.
    tmp_R4a = [0, 0, 0, 0]

    for i in range(len(out_t)):
        # R1: leaky integration of the sign-flipped stimulus
        tmp_R1 += dt * (-stim.data[i] - tmp_R1) / model.tau1

        # leaky integrated charge accumulation
        tmp_chargeacc += dt * np.maximum(stim.data[i], 0)
        tmp_ca += dt * (tmp_chargeacc - tmp_ca) / model.tau2

        # R3: fast response minus weighted charge accumulation, rectified
        tmp_R3 = np.maximum(tmp_R1 - model.eps * tmp_ca, 0)

        # R4: scaled R3 passed through a cascade of 3 leaky integrators; each
        # stage already sees the updated value of the stage before it.
        tmp_R4a[0] = tmp_R3 / maxR3 * sc_fac
        for j in range(3):
            tmp_R4a[j + 1] += dt * (tmp_R4a[j] - tmp_R4a[j + 1]) / model.tau3

        out_R4[i] = tmp_R4a[-1]

    return out_t, out_R4
In [5]:
%%cython
import numpy as np
cimport numpy as np
import scipy.special as ss
import cython
# C-level maximum of two doubles (no Python call overhead).
cdef inline double float_max(double a, double b):
    return a if a >= b else b


# `np.float` was only an alias for the builtin (64-bit) float and has been
# removed from NumPy, so the 64-bit dtype is spelled out explicitly.
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t


def cythoncascade(stim, model, maxR3=99.8873446571, skip_i=0):
    """Cython port of the finite-difference Nanduri cascade.

    Every stage is a leaky integrator advanced by one step of size
    ``stim.tsample`` per sample: ``x += dt * (input - x) / tau``. All state
    is held in C doubles so it matches the 64-bit arrays; the previous
    ``cdef float`` declarations were 32-bit.

    Parameters
    ----------
    stim : object
        Must provide ``tsample`` (step size), ``duration`` and ``data``
        (1-D samples, converted to float64). ``data`` needs at least as
        many samples as ``np.arange(0, stim.duration, stim.tsample)`` has
        entries.
    model : object
        Must provide ``tau1``, ``tau2``, ``tau3``, ``eps``, ``asymptote``,
        ``shift`` and ``slope``; each is read once, before the loop.
    maxR3 : float, optional
        Value the rectified signal R3 is divided by, and the argument of the
        sigmoid that sets the output gain.
        NOTE(review): the default is presumably the peak R3 for the stimulus
        used in this notebook -- confirm before reusing with other stimuli.
    skip_i : int, optional
        Unused; kept so that existing calls remain valid.

    Returns
    -------
    out_t : np.ndarray
        ``np.arange(0, stim.duration, stim.tsample)`` as float64.
    out_R4 : np.ndarray
        Output of the last of the three cascaded integrators at every step;
        same shape as `out_t`.
    """
    # Copy every model parameter into a C double up front so the loop body
    # performs no Python attribute lookups.
    cdef double dt = stim.tsample
    cdef double tau1 = model.tau1
    cdef double tau2 = model.tau2
    cdef double tau3 = model.tau3
    cdef double eps = model.eps
    cdef double asymptote = model.asymptote
    cdef double shift = model.shift
    cdef double slope = model.slope
    cdef double max_r3 = maxR3

    # np.asarray is a no-op for float64 input and converts anything else, so
    # the typed buffer below does not reject e.g. integer samples.
    cdef np.ndarray[DTYPE_t] stimdata = np.asarray(stim.data, dtype=DTYPE)
    cdef np.ndarray[DTYPE_t] out_t = np.arange(0, stim.duration, stim.tsample,
                                               dtype=DTYPE)
    cdef np.ndarray[DTYPE_t] out_R4 = np.zeros_like(out_t, dtype=DTYPE)

    # The nonlinearity's scale factor depends only on the arguments, so the
    # (Python-level) expit call is made once instead of on every time step.
    cdef double sc_fac = asymptote * ss.expit((max_r3 - shift) / slope)

    cdef double tmp_R1 = 0          # fast response
    cdef double tmp_chargeacc = 0   # integral of the positive stimulus part
    cdef double tmp_ca = 0          # leaky-integrated charge accumulation
    cdef double tmp_R3 = 0          # rectified difference of the two
    # Input slot [0] followed by the three stages of the slow response.
    cdef double stage[4]
    cdef Py_ssize_t n_steps = out_t.shape[0]
    cdef Py_ssize_t i, j

    for j in range(4):
        stage[j] = 0.0

    for i in range(n_steps):
        # R1: leaky integration of the sign-flipped stimulus
        tmp_R1 += dt * (-stimdata[i] - tmp_R1) / tau1

        # leaky integrated charge accumulation
        tmp_chargeacc += dt * float_max(stimdata[i], 0)
        tmp_ca += dt * (tmp_chargeacc - tmp_ca) / tau2

        # R3: fast response minus weighted charge accumulation, rectified
        tmp_R3 = float_max(tmp_R1 - eps * tmp_ca, 0)

        # R4: scaled R3 passed through a cascade of 3 leaky integrators; each
        # stage already sees the updated value of the stage before it.
        stage[0] = tmp_R3 / max_r3 * sc_fac
        for j in range(3):
            stage[j + 1] += dt * (stage[j] - stage[j + 1]) / tau3

        out_R4[i] = stage[3]

    return out_t, out_R4
In [6]:
# --- Stimulus ---------------------------------------------------------------
# Step size shared by the stimulus and by both models.
tsample = 0.005 / 1000
stim = p2p.stimuli.PulseTrain(
    tsample,
    freq=20,
    amp=150,
    pulse_dur=0.45 / 1000,
    dur=0.5,
)

# --- Inputs common to both model_cascade calls --------------------------------
# 2 x 1 integer array of ones, passed as `in_arr`.
ecm = np.array([[1], [1]])
layers = ['GCL']
use_jit = True

# --- Current pulse2percept model ----------------------------------------------
nanduri = p2p.retina.Nanduri2012(tsample=tsample)
nanduri_out = nanduri.model_cascade(ecm, [stim.data], layers, use_jit)
# Time axis with one entry per output sample.
nanduri_t = np.arange(len(nanduri_out.data)) * tsample

# --- Legacy convolution-based model (defined earlier in this notebook) --------
legacy = LegacyNanduri2012(tsample=tsample)
legacy_out = legacy.model_cascade(ecm, [stim.data], layers, use_jit)
legacy_t = np.arange(len(legacy_out.data)) * tsample
In [7]:
# Pure-Python finite-difference cascade on the same stimulus, reading its
# parameters (tau1..tau3, eps, asymptote, shift, slope) from `nanduri`.
finite_diff_t, finite_diff_out = finite_diff(stim, nanduri)
In [8]:
# Same finite-difference cascade, compiled by the %%cython cell; it also
# takes its parameters from `nanduri`.
cython_t, cython_out = cythoncascade(stim, nanduri)
In [9]:
fig, ax = plt.subplots(figsize=(10, 5))

# Drawn from widest to narrowest so that each trace stays visible on top of
# the ones beneath it when the outputs coincide.
ax.plot(legacy_t, legacy_out.data, linewidth=10, label='Nanduri 2012 (Legacy)')
ax.plot(nanduri_t, nanduri_out.data, linewidth=5, label='Nanduri 2012 (Cython)')
ax.plot(cython_t, cython_out, linewidth=2, label='Finite difference (Cython)')
# The pure-Python finite-difference result was computed but never displayed.
ax.plot(finite_diff_t, finite_diff_out, linewidth=1,
        label='Finite difference (pure Python)')

# NOTE(review): the unit assumes tsample is expressed in seconds -- confirm.
ax.set_xlabel('Time (s)')
ax.set_ylabel('Model output')
# Trailing semicolon keeps the Legend repr out of the cell output.
ax.legend(loc='lower right');
Out[9]:
In [10]:
# Legacy convolution cascade vs. the current pulse2percept model: do they
# agree within atol=5e-3 (plus np.allclose's default relative tolerance)?
np.allclose(legacy_out.data, nanduri_out.data, atol=5e-3)
Out[10]:
In [11]:
# Cython finite-difference cascade vs. the current pulse2percept model, using
# the same tolerance as the legacy comparison.
np.allclose(cython_out, nanduri_out.data, atol=5e-3)
Out[11]:
Convolutions (pure Python):
In [13]:
%timeit out = legacy.model_cascade(ecm, [stim.data], layers, use_jit)
Finite difference model (pure Python):
In [14]:
%timeit out = finite_diff(stim, nanduri)
Finite difference model (naive Cython):
In [15]:
%timeit out_R4 = cythoncascade(stim, nanduri)
Finite difference model (pulse2percept):
In [16]:
%timeit out = nanduri.model_cascade(ecm, [stim.data], layers, use_jit)
In [ ]: