Note: The following notebook contains code in addition to text and figures. By default, the code is hidden; click the icon that looks like an eye in the toolbar above to show it. To run the code, open the Cell menu and choose "Run All".
In this notebook we investigate the effect of deafferentation on the model. Specifically, what happens when we reduce the number of input fibres but, through a homeostatic process, restore the original chopper cell firing rate? For simplicity, we only cover the case of purely excitatory inputs (no inhibition).
The interactive figure below shows the number of inputs $N$ on the x-axis and the output firing rate on the y-axis; the colour indicates the coefficient of variation (CV) of the output interspike intervals at that point. To generate the figure, we adjust the synaptic weight $w$ and the input firing rate $\rho$ so as to span the range of output firing rates. Specifically, we fix a "reference" pair of firing rates: by default, an input firing rate of 200 sp/s should produce an output firing rate of 100 sp/s (the input rate can be changed in the interactive figure). For each $N$, we use this reference pair to set the synaptic weight $w$, and then vary the input firing rate $\rho$ to cover the range of output firing rates.
Cautionary notes:
In [ ]:
# Imports etc.
%matplotlib inline
from brian2 import *
from functools import partial
import ipywidgets as ipw
from matplotlib import cm
import matplotlib
import matplotlib.patheffects as PathEffects
from scipy.interpolate import Rbf
import joblib
import warnings
warnings.filterwarnings("ignore")
# Integration time step used by the Brian2 simulations below.
defaultclock.dt = 0.05*ms
# On-disk cache (directory "joblib") for the expensive simulations below.
# joblib >= 0.12 names the cache-directory argument ``location``; the old
# ``cachedir`` name was deprecated and later removed. Try the new name first
# and fall back for old joblib versions that only know ``cachedir``.
try:
    mem = joblib.Memory(location="joblib", verbose=0)
except TypeError:
    mem = joblib.Memory(cachedir="joblib", verbose=0)
# Later on we generate a bunch of values at points that aren't on a
# grid: we use this to interpolate to a grid so we can display as an
# image. We use scipy's Radial basis function interpolation method.
def interpolate_to_image(x, y, c, xmin, xmax, xn, ymin, ymax, yn, smooth=1.0):
    """Resample scattered samples ``c`` at points ``(x, y)`` onto a regular grid.

    A linear radial-basis-function interpolant (``scipy.interpolate.Rbf``
    with ``function='linear'``) is fitted to the samples and evaluated on
    ``xn`` evenly spaced x values in ``[xmin, xmax]`` and ``yn`` evenly
    spaced y values in ``[ymin, ymax]``.

    ``smooth`` is passed straight to ``Rbf``; 0 interpolates the samples
    exactly, larger values smooth them.

    Returns a 2-D array of shape ``(yn, xn)``: row ``i`` corresponds to the
    ``i``-th y value and column ``j`` to the ``j``-th x value.
    """
    interpolant = Rbf(x, y, c, function='linear', smooth=smooth)
    grid_x, grid_y = meshgrid(linspace(xmin, xmax, xn),
                              linspace(ymin, ymax, yn))
    return interpolant(grid_x, grid_y)
# Utility function
def normed(x, ref=None, logscale=False):
    """Linearly rescale ``x`` so that the range of ``ref`` maps onto [0, 1].

    Parameters
    ----------
    x : scalar, sequence or array (Brian2 quantities are accepted; their
        units are stripped and the raw values used).
    ref : values whose minimum maps to 0 and maximum maps to 1. Defaults
        to ``x`` itself.
    logscale : if True, rescale ``log(x)`` against ``log(ref)`` instead.

    Returns
    -------
    ``(x - min(ref)) / (max(ref) - min(ref))`` (in log space if requested),
    as a float scalar for scalar ``x`` or a float array otherwise. If
    ``ref`` has zero range the division yields nan/inf.
    """
    import numpy as np
    if ref is None:
        ref = x
    # Convert both inputs the same way so scalars, lists and (unit-carrying)
    # arrays are all handled consistently; previously ``ref`` was only
    # converted when ``x`` happened to be an ndarray, and list ``x`` raised.
    x = np.asarray(x, dtype=float)
    ref = np.asarray(ref, dtype=float)
    if logscale:
        x = np.log(x)
        ref = np.log(ref)
    lo = np.amin(ref)
    return (x - lo)/(np.amax(ref) - lo)
# The model itself: this is copied directly from level_dependence.ipynb
# Module-level handle on the Brian2 network; it is built lazily on the first
# simulation and reused for later calls while ``repeats`` stays the same.
thenet = None
@mem.cache
def get_steady_state_data(
        repeats=1000,
        duration=250*ms,
        skip=50*ms,
        mu=4.0,
        sigma=0.1,
        tau=5*ms,
        refractory=0*ms,
        ):
    """Simulate ``repeats`` noisy integrate-and-fire neurons and summarise them.

    Each neuron follows ``dv/dt = (mu-v)/tau + sigma*tau**-0.5*xi``, fires
    when ``v > 1`` and is then reset to ``v = 0`` and held for
    ``refractory``. Spikes in the first ``skip`` of the run are not recorded.

    Parameters
    ----------
    repeats : number of neurons simulated in parallel.
    duration : total simulated time, including ``skip``.
    skip : initial warm-up period excluded from the statistics.
    mu, sigma : mean drive and noise amplitude in the equation above.
    tau : membrane time constant.
    refractory : refractory period.

    Returns
    -------
    cv : coefficient of variation of the interspike intervals pooled over
        all neurons, or ``nan`` if fewer than two intervals were observed.
    rate : mean firing rate per neuron over the recorded ``duration-skip``.

    Results are memoised by ``mem.cache`` (joblib), so the simulation, and
    any use of ``thenet``, only happens on a cache miss.
    """
    global thenet
    # (Re)build the network only if none exists yet or it has a different
    # number of neurons than requested; otherwise reuse it.
    if thenet is None or len(thenet['G'])!=repeats:
        eqs = '''
dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
refrac : second
'''
        G = NeuronGroup(repeats, eqs, threshold='v>1', reset='v=0', refractory='refrac',
                        name='G', method='euler')
        M = SpikeMonitor(G, name='M')
        thenet = Network(G, M)
        # Snapshot the freshly built (not yet run) state so later calls can
        # rewind to it.
        thenet.store()
    else:
        G = thenet['G']
        M = thenet['M']
        # Rewind the reused network to the snapshot taken at construction.
        thenet.restore()
    # Per-call initial conditions.
    G.refrac = refractory
    G.not_refractory = True
    G.lastspike = -inf*second
    G.v = 0
    # mu, tau and sigma in the equations are resolved from this namespace.
    ns = {'mu': mu, 'tau': tau, 'sigma': sigma}
    # Warm-up: run for ``skip`` without recording spikes, then record.
    M.active = False
    thenet.run(skip, namespace=ns)
    M.active = True
    thenet.run(duration-skip, namespace=ns)
    trains = M.spike_trains()
    # Interspike intervals of each neuron that fired at least twice.
    dtrains = [diff(train) for train in trains.values() if len(train)>1]
    if len(dtrains):
        # NOTE(review): if hstack keeps Brian2 units this yields second**2;
        # the CV below is a ratio, so its value is unaffected either way.
        isi = hstack(dtrains)*second
    else:
        isi = array([])
    # CV is only computed with at least two pooled intervals.
    if len(isi)>1:
        cv = std(isi)/mean(isi)
    else:
        cv = nan
    # Total recorded spikes per neuron per unit of recorded time.
    rate = len(M.t)/(repeats*(duration-skip))
    return cv, rate
# This function finds the points on a particular column, i.e. the points
# for a fixed number N=num_anf of input neurons. Given a target output
# firing rate and input firing rate, it finds the weight w that gives
# that output rate for that input rate. Then, given that weight, it
# modifies the input firing rate up and down until the full range of
# output firing rates has been covered. The rate at which it moves up
# and down is a parameter (search_fraction).
@mem.cache
def find_points(num_anf, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate, search_fraction,
                refractory=0.6*ms, tau=6*ms,
                repeats=50, duration=150*ms, skip=50*ms, spikes_per_repeat=100,
                max_weight_steps=100, max_sweep_steps=1000,
                ):
    """Sample (input rate, output rate, CV) for ``num_anf`` excitatory inputs.

    Parameters
    ----------
    num_anf : number of inputs N.
    cn_rate_min, cn_rate_max : output-rate range the sweep tries to cover.
    target_cn_rate, target_anf_rate : reference pair fixing the synaptic
        weight w: input rate ``target_anf_rate`` should give output rate
        ``target_cn_rate``.
    search_fraction : relative step size of both the weight search and the
        input-rate sweep.
    refractory, tau, repeats, duration, skip : forwarded to
        ``get_steady_state_data``.
    spikes_per_repeat : unused; kept for interface compatibility.
    max_weight_steps : maximum number of steps of the weight search.
    max_sweep_steps : maximum number of steps of each input-rate sweep.
        Guarantees termination when the output rate cannot reach
        ``cn_rate_max`` (e.g. when it saturates near 1/refractory).

    Returns
    -------
    all_anf_rate, all_cn_rate, all_cv : arrays of input rate, output rate and
        CV for every simulated point, sorted by output rate.
    """
    ssd = partial(get_steady_state_data, repeats=repeats, duration=duration, skip=skip,
                  refractory=refractory, tau=tau)
    K = num_anf*target_anf_rate*tau
    # Step 1: search for w giving target_cn_rate. This is a stochastic search
    # but we treat it as deterministic and hope that works well enough.
    # Start with a guess based on the case when sigma = 0:
    #   mu = w*K, sigma = w*sqrt(K)
    #   predicted rate R = -1/(tau*log(1-1/mu))
    #   so want mu = 1/(1-exp(-1/(tau*R))) = w*K
    w = 1.0/(1.0-exp(-1/(tau*target_cn_rate)))/K
    _, rate = ssd(mu=w*K, sigma=w*sqrt(K))
    # Step w in the direction of the target until the output rate crosses it.
    w_prev, rate_prev = w, rate
    crossed = False
    if rate != target_cn_rate:
        for _ in range(max_weight_steps):
            w = w_prev+w_prev*search_fraction*sign(target_cn_rate-rate_prev)
            _, rate = ssd(mu=w*K, sigma=w*sqrt(K))
            if rate == target_cn_rate:
                # Exact hit: keep this w (stepping on would use sign(0) = 0
                # and stall, and the interpolation below would be 0/0).
                break
            if ((rate_prev<target_cn_rate and rate>target_cn_rate) or
                    (rate_prev>target_cn_rate and rate<target_cn_rate)):
                crossed = True
                break
            w_prev, rate_prev = w, rate
    if crossed:
        # Linearly interpolate between the two weights bracketing the target.
        w = w_prev+(target_cn_rate-rate_prev)*(w-w_prev)/(rate-rate_prev)
    # Otherwise w either hits the target exactly or, if the search ran out of
    # steps, is the last weight tried (interpolating there would be 0/0).
    # Measure CV and rate at the *final* weight so the first sample is
    # self-consistent.
    cv, cn_rate = ssd(mu=w*K, sigma=w*sqrt(K))

    # Step 2: with w fixed, sweep the input rate up and down until the output
    # rate covers [cn_rate_min, cn_rate_max], recording every point.
    all_anf_rate = [target_anf_rate]
    all_cv = [cv]
    all_cn_rate = [cn_rate]

    def measure(anf_rate):
        # Simulate at this input rate with the fixed weight; record the result.
        K = num_anf*anf_rate*tau
        cv, cn_rate = ssd(mu=w*K, sigma=w*sqrt(K))
        all_anf_rate.append(anf_rate)
        all_cv.append(cv)
        all_cn_rate.append(cn_rate)
        return cn_rate

    anf_rate, cn_rate = all_anf_rate[0], all_cn_rate[0]
    for _ in range(max_sweep_steps):
        if not (cn_rate<cn_rate_max):
            break
        anf_rate = anf_rate+anf_rate*search_fraction
        cn_rate = measure(anf_rate)
    anf_rate, cn_rate = all_anf_rate[0], all_cn_rate[0]
    for _ in range(max_sweep_steps):
        if not (cn_rate>cn_rate_min):
            break
        anf_rate = anf_rate-anf_rate*search_fraction
        cn_rate = measure(anf_rate)

    all_anf_rate = array(all_anf_rate)
    all_cn_rate = array(all_cn_rate)
    all_cv = array(all_cv)
    order = argsort(all_cn_rate)
    return all_anf_rate[order], all_cn_rate[order], all_cv[order]
# This constructs the image column by column
@mem.cache
def get_image(num_anf_range, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
              search_fraction,
              refractory=0.6*ms, tau=6*ms,
              repeats=50, duration=150*ms, skip=50*ms,
              ):
    """Collect ``find_points`` samples for every input count in ``num_anf_range``.

    All other arguments are forwarded to ``find_points``.

    Returns
    -------
    num_anf, rate, cv : flat arrays of equal length; sample ``i`` has input
        count ``num_anf[i]``, output firing rate ``rate[i]`` and CV ``cv[i]``.

    A progress line is written to stdout; since the function is memoised by
    ``mem.cache``, this only appears when the result is actually computed.
    """
    import sys
    rate_all = []
    cv_all = []
    num_anf_all = []
    # sys.stdout.write works under both Python 2 and Python 3, unlike the
    # Python-2-only ``print`` statement previously used here.
    sys.stdout.write('Working on: \n')
    for i, num_anf in enumerate(num_anf_range):
        sys.stdout.write('%d/%d ' % (i, len(num_anf_range)))
        sys.stdout.flush()
        _, rate, cv = find_points(num_anf, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
                                  search_fraction, refractory=refractory, tau=tau, repeats=repeats,
                                  duration=duration, skip=skip)
        rate_all.append(rate)
        cv_all.append(cv)
        # Tag every sample of this column with its input count.
        num_anf_all.append(full(len(cv), num_anf))
    return hstack(num_anf_all), hstack(rate_all), hstack(cv_all)
# This does some nice plotting for the image
def plotmap(num_anf_range, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
            search_fraction,
            refractory=0.6*ms, tau=6*ms,
            repeats=50, duration=150*ms, skip=50*ms,
            log_anf_scale=False, cmap_01=True,
            levels=(0.15, 0.35, 0.8),
            show_interpolated_data=True,
            show_colorbar=True,
            cmap=cm.YlGnBu_r,
            ):
    """Draw the CV map over (number of inputs, output rate) into the current axes.

    Parameters
    ----------
    num_anf_range ... skip : forwarded to ``get_image``; ``cn_rate_min`` and
        ``cn_rate_max`` also set the y-range of the plotted map.
    log_anf_scale : use a logarithmic x-axis (and interpolate in log N).
    cmap_01 : fix the colour range to [0, 1]; otherwise use the map's
        min/max.
    levels : CV values at which white contour lines are drawn.
    show_interpolated_data : show the interpolated image; otherwise scatter
        the raw samples on a black background.
    show_colorbar : add a colour bar labelled "CV".
    cmap : matplotlib colormap.
    """
    num_anf_min = amin(num_anf_range)
    num_anf_max = amax(num_anf_range)
    num_anf_orig, rate_orig, cv = get_image(num_anf_range, cn_rate_min, cn_rate_max,
                                            target_cn_rate, target_anf_rate, search_fraction,
                                            refractory=refractory, tau=tau, repeats=repeats,
                                            duration=duration, skip=skip)
    n_grid = 100
    # Samples with an undefined CV (nan) are left out of the interpolation:
    # a single nan would make the RBF fit, and hence the whole image, nan.
    finite = isfinite(cv)
    # Interpolate in normalised coordinates, where both axes span [0, 1].
    img = interpolate_to_image(
        normed(num_anf_orig[finite], num_anf_orig, logscale=log_anf_scale),
        normed(rate_orig[finite], rate_orig),
        cv[finite],
        normed(num_anf_min, num_anf_orig, logscale=log_anf_scale),
        normed(num_anf_max, num_anf_orig, logscale=log_anf_scale), n_grid,
        normed(cn_rate_min, rate_orig), normed(cn_rate_max, rate_orig), n_grid,
    )
    if cmap_01:
        vmin, vmax = 0, 1
    else:
        vmin, vmax = amin(img), amax(img)
    extent = (num_anf_min, num_anf_max, float(cn_rate_min), float(cn_rate_max))
    if show_interpolated_data:
        # matplotlib only accepts origin='upper'/'lower' ('lower left' is
        # rejected by current versions and was treated as 'lower' by old ones).
        # NOTE(review): with log_anf_scale=True the image columns are
        # log-spaced but drawn with a linear extent; presumably misaligned
        # with the data - consider pcolormesh if that mode is used.
        imshow(img, origin='lower', aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax,
               extent=extent)
    else:
        scatter(num_anf_orig, rate_orig, c=cv, cmap=cmap, vmin=vmin, vmax=vmax, lw=0, s=150)
        ax = gca()
        # Axes.set_axis_bgcolor was removed in matplotlib 2.2 in favour of
        # set_facecolor (available since 2.0); fall back for old versions.
        set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
        set_background('k')
    if log_anf_scale:
        xscale('log')
        gca().xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%d'))
    # The colour bar must be created before the contours so it attaches to
    # the image/scatter rather than to the contour set.
    if show_colorbar:
        cb = colorbar()
        cb.set_label('CV', rotation=270, labelpad=20)
    xlabel('Number of inputs $N$')
    ylabel('Output firing rate (sp/s)')
    # Data coordinates of the interpolation grid: linear in output rate and
    # linear or geometric in N. Passing them explicitly places the contours
    # where the samples were interpolated, also on a log x-axis.
    if log_anf_scale:
        grid_anf = exp(linspace(log(num_anf_min), log(num_anf_max), n_grid))
    else:
        grid_anf = linspace(num_anf_min, num_anf_max, n_grid)
    grid_rate = linspace(float(cn_rate_min), float(cn_rate_max), n_grid)
    cs = contour(grid_anf, grid_rate, img,
                 levels=levels, colors=['w']*len(levels))
    clabel(cs, colors='w', inline=True)
# This function is passed to ipw.interact to control the GUI
def deafferentation_plot(tau_ms=6, refractory_ms=0.1,
                         anf_rate_100=200,
                         quality="Normal", show_interpolated_data=True):
    """Widget callback: draw the deafferentation CV map for the GUI settings.

    Output rates from 50 to 300 sp/s are mapped, with the synaptic weight
    chosen so that an input rate of ``anf_rate_100`` sp/s gives 100 sp/s
    output. ``quality == "Normal"`` produces a quick, coarse map; any other
    value produces the slow, fine-grained map.
    """
    if quality == "Normal":
        # Coarse: every 5th input count, 10% search steps, few short repeats.
        (num_anf_range, search_fraction,
         repeats, duration, skip) = arange(1, 51, 5), 0.1, 50, 150*ms, 50*ms
    else:
        # Fine: every input count, 1% search steps, many long repeats.
        (num_anf_range, search_fraction,
         repeats, duration, skip) = arange(1, 51, 1), 0.01, 4000, 350*ms, 100*ms
    plotmap(num_anf_range, 50*Hz, 300*Hz,
            100*Hz, anf_rate_100*Hz, search_fraction,
            tau=tau_ms*ms, refractory=refractory_ms*ms,
            repeats=repeats, duration=duration, skip=skip,
            show_interpolated_data=show_interpolated_data)
# Controls for deafferentation_plot; show_interpolated_data has no entry
# here, so ipw.interact builds its control from the default value.
widgets = dict(
    tau_ms=ipw.FloatSlider(min=0.1, max=15, step=0.1, value=6,
                           continuous_update=False,
                           description=r"Membrane time constant $\tau$ (ms)"),
    refractory_ms=ipw.FloatSlider(min=0, max=5, step=0.1, value=0.1,
                                  continuous_update=False,
                                  description=r"Refractory period $t_\mathrm{ref}$ (ms)"),
    anf_rate_100=ipw.FloatSlider(min=50, max=500, step=10, value=200,
                                 continuous_update=False,
                                 description=r"AN firing rate for 100 sp/s output (sp/s)"),
    quality=ipw.Dropdown(description="Quality",
                         options=["Normal", "Publication (very slow)"],
                         value="Normal"),
    )
# Full-width controls with a wide label column so the long descriptions fit.
for widget in widgets.values():
    widget.layout.width = '100%'
    widget.style = {'description_width': '30%'}
# Run the GUI. ipw.interact displays the controls itself and returns the
# callback function, so wrapping it in display() only added a stray
# "<function ...>" output; the trailing ';' suppresses the cell's echo of
# that return value.
ipw.interact(deafferentation_plot, **widgets);