Note: the following notebook contains code in addition to text and figures. By default, the code is hidden; click the icon that looks like an eye in the toolbar above to show it. To run the code, click the "Cell" menu, then "Run All".

Deafferentation

In this notebook we investigate the effect of deafferentation on the model. Specifically, what happens when we reduce the number of input fibres but, through a homeostatic process, restore the original chopper cell firing rate? For simplicity, we only consider excitatory inputs (no inhibition).

The interactive figure below shows the number of inputs on the x-axis and the output firing rate on the y-axis. The colour indicates the coefficient of variation (CV) of the interspike intervals at that point. To cover the range of output firing rates, we modify the synaptic weight $w$ and the input firing rate $\rho$. Specifically, we use a pair of "reference" firing rates: by default, an input firing rate of 200 sp/s should always produce an output firing rate of 100 sp/s (the reference input rate can be changed in the interactive figure). For each number of inputs, we use this reference pair to set the synaptic weight $w$, and then vary the input firing rate $\rho$ to cover the range of output firing rates.

Cautionary notes:

  • The computation is rather heavy for this figure, so it will take a couple of minutes for the first image to appear and for each change in parameters to be computed. Going back to a parameter set that you have already computed will be instantaneous though.
  • The "publication" quality will take a very long time to compute, and the difference is fairly minor so it's probably not worth using this option.
  • Occasionally, the searching algorithm (which is not very sophisticated) will get stuck for some parameters. At normal quality, if the figure doesn't appear after a few minutes, restart the computation. To do this, click the "Kernel" menu, then "Restart and run all".

In [ ]:
# Imports etc.
%matplotlib inline
from brian2 import *
from functools import partial
import ipywidgets as ipw
from matplotlib import cm
import matplotlib
import matplotlib.patheffects as PathEffects
from scipy.interpolate import Rbf
import joblib

import warnings
# Silence all warnings globally.
# NOTE(review): this also hides deprecation warnings from the libraries used
# in this notebook, so API removals surface only as hard errors.
warnings.filterwarnings("ignore")

# Default simulation time step; the NeuronGroup built in
# get_steady_state_data does not set its own dt, so it uses this clock.
defaultclock.dt = 0.05*ms

# On-disk cache (directory "joblib") used by the @mem.cache decorators.
# `location` replaces joblib's old `cachedir` argument, which was deprecated
# and later removed from joblib.Memory.
mem = joblib.Memory(location="joblib", verbose=0)

# Later on we generate values at scattered points that are not on a grid;
# this resamples them onto a regular grid so they can be displayed as an
# image, using scipy's radial basis function (Rbf) interpolation.
def interpolate_to_image(x, y, c, xmin, xmax, xn, ymin, ymax, yn, smooth=1.0):
    """Resample scattered samples ``c`` located at ``(x, y)`` onto a grid.

    The result is a 2D array with ``yn`` rows and ``xn`` columns: entry
    ``[i, j]`` is the interpolated value at the ``i``-th of ``yn`` evenly
    spaced y values in ``[ymin, ymax]`` and the ``j``-th of ``xn`` evenly
    spaced x values in ``[xmin, xmax]``. ``smooth`` is forwarded to ``Rbf``
    (0 interpolates exactly through the samples; larger values smooth more).
    """
    interpolant = Rbf(x, y, c, function='linear', smooth=smooth)
    grid_x, grid_y = meshgrid(linspace(xmin, xmax, xn),
                              linspace(ymin, ymax, yn))
    return interpolant(grid_x, grid_y)

# Utility function
def normed(x, ref=None, logscale=False):
    """Linearly rescale ``x`` so that the range of ``ref`` maps onto [0, 1].

    Parameters
    ----------
    x : scalar, list, tuple or array
        Value(s) to rescale. Arrays (including ndarray subclasses such as
        brian2 quantities), lists and tuples are converted to plain float
        arrays, which discards any physical units.
    ref : scalar, list, tuple or array, optional
        Reference values: ``min(ref)`` maps to 0 and ``max(ref)`` maps to 1.
        Defaults to ``x`` itself.
    logscale : bool
        If True, rescale ``log(x)`` relative to ``log(ref)`` instead.

    Returns
    -------
    The rescaled value(s). Values outside the range of ``ref`` fall outside
    [0, 1]; if ``ref`` has zero range the numpy division yields inf/nan.
    """
    if ref is None:
        ref = x
    if isinstance(x, (ndarray, list, tuple)):
        # array() returns a plain ndarray even for ndarray subclasses, which
        # strips brian2 units; converting first also lets lists/tuples work
        # (``list*1.0`` raises TypeError).
        x = array(x, dtype=float)
    else:
        x = x*1.0
    ref = array(ref)
    if logscale:
        x = log(x)
        ref = log(ref)
    return (x-amin(ref))/(amax(ref)-amin(ref))

# The model itself: this is copied directly from level_dependence.ipynb
# Module-level handle to the brian2 Network built by get_steady_state_data;
# it is reused across calls and rebuilt only when `repeats` changes.
thenet = None
@mem.cache
def get_steady_state_data(
        repeats=1000,
        duration=250*ms,
        skip=50*ms,
        mu=4.0,
        sigma=0.1,
        tau=5*ms,
        refractory=0*ms,
        ):
    """Simulate `repeats` noisy leaky integrate-and-fire neurons in parallel.

    Each neuron obeys dv/dt = (mu - v)/tau + sigma*tau**-0.5*xi, spikes when
    v > 1, is reset to v = 0, and does not integrate for `refractory`
    afterwards. Every call starts from v = 0 and runs for `duration`;
    spikes in the first `skip` are not recorded.

    Results are memoised on disk by `mem.cache`, so a repeated call with
    identical arguments returns the stored result without simulating.

    Returns
    -------
    cv : float (dimensionless)
        Coefficient of variation (std/mean) of the interspike intervals,
        pooled over all neurons; nan if fewer than two intervals were
        recorded.
    rate : brian2 quantity
        Mean firing rate per neuron over the recorded window:
        spike count / (repeats*(duration - skip)).
    """
    global thenet
    # Build the network on first use, or when a different number of neurons
    # is requested; otherwise reuse the existing objects.
    if thenet is None or len(thenet['G'])!=repeats:
        eqs = '''
        dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
        refrac : second
        '''
        G = NeuronGroup(repeats, eqs, threshold='v>1', reset='v=0', refractory='refrac',
                        name='G', method='euler')
        M = SpikeMonitor(G, name='M')
        thenet = Network(G, M)
        # Snapshot of the freshly built network, restored on every call below.
        thenet.store()
    else:
        G = thenet['G']
        M = thenet['M']
    thenet.restore()
    # Reset per-neuron state so no refractoriness or membrane potential
    # carries over from a previous call.
    G.refrac = refractory
    G.not_refractory = True
    G.lastspike = -inf*second
    G.v = 0
    # mu, tau and sigma in the equations are resolved from this namespace.
    ns = {'mu': mu, 'tau': tau, 'sigma': sigma}
    # Run the initial `skip` period without recording spikes.
    M.active = False
    thenet.run(skip, namespace=ns)
    M.active = True
    thenet.run(duration-skip, namespace=ns)    
    trains = M.spike_trains()
    # Interspike intervals from every neuron with at least two spikes.
    dtrains = [diff(train) for train in trains.values() if len(train)>1]
    if len(dtrains):
        isi = hstack(dtrains)*second
    else:
        isi = array([])
    # CV of the pooled intervals; undefined (nan) with fewer than two.
    if len(isi)>1:
        cv = std(isi)/mean(isi)
    else:
        cv = nan
    rate = len(M.t)/(repeats*(duration-skip))
    return cv, rate

# This function finds the points on a particular column, i.e. the points
# for a fixed number N=num_anf of input neurons. Given a target output
# firing rate and input firing rate, it finds the weight w that gives
# that output rate for that input rate. Then, given that weight, it
# modifies the input firing rate up and down until the full range of
# output firing rates has been covered. The rate at which it moves up
# and down is a parameter (search_fraction).
@mem.cache
def find_points(num_anf, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate, search_fraction,
                refractory=0.6*ms, tau=6*ms,
                repeats=50, duration=150*ms, skip=50*ms, spikes_per_repeat=100,
                max_weight_steps=100, max_sweep_steps=1000,
                ):
    """Compute (input rate, output rate, CV) samples for ``num_anf`` inputs.

    Step 1 searches for the synaptic weight ``w`` with which ``num_anf``
    inputs firing at ``target_anf_rate`` produce an output rate of
    ``target_cn_rate``. Step 2 keeps ``w`` fixed and scales the input rate
    up (then down) by a factor ``1 +/- search_fraction`` per step until the
    output rate reaches ``cn_rate_max`` (then ``cn_rate_min``).

    The input is summarised by ``K = num_anf*anf_rate*tau`` and simulated by
    ``get_steady_state_data`` with ``mu = w*K`` and ``sigma = w*sqrt(K)``.

    Parameters
    ----------
    max_weight_steps : int
        Maximum number of steps in the weight search. If the target output
        rate is never bracketed, the last weight tried is used.
    max_sweep_steps : int
        Maximum number of input-rate steps in each sweep direction. This
        bounds the upward sweep when ``cn_rate_max`` cannot be reached
        (e.g. when the refractory period caps the output rate below it).
    spikes_per_repeat : unused; kept for interface compatibility.

    Returns
    -------
    all_anf_rate, all_cn_rate, all_cv : arrays
        Input rate, output rate and CV of every simulated point, sorted by
        output rate.
    """
    ssd = partial(get_steady_state_data, repeats=repeats, duration=duration, skip=skip,
                  refractory=refractory, tau=tau)
    K = num_anf*target_anf_rate*tau
    # Step 1: search for w giving target_cn_rate. This is a stochastic search
    # but we just treat it as if it were not and hope that works well enough.
    # Start with a guess based on the case when sigma = 0:
    # mu = w*K, sigma=w*sqrt(K)
    # predicted rate R = -1/(tau*log(1-1/mu))
    # so want mu=1/(1-exp(-1/(tau*R)))=w*K
    w = 1.0/(1.0-exp(-1/(tau*target_cn_rate)))/K
    _, rate = ssd(mu=w*K, sigma=w*sqrt(K))
    # Step w by a fraction of itself towards the target until the output
    # rate crosses (or lands exactly on) the target, then interpolate.
    w_prev, rate_prev = w, rate
    crossed = False
    for _ in range(max_weight_steps):
        if rate_prev == target_cn_rate:
            # Exact hit: sign() would be 0 and the search could never cross.
            break
        w = w_prev+w_prev*search_fraction*sign(target_cn_rate-rate_prev)
        _, rate = ssd(mu=w*K, sigma=w*sqrt(K))
        if ((rate_prev<target_cn_rate and rate>=target_cn_rate) or
            (rate_prev>target_cn_rate and rate<=target_cn_rate)):
            crossed = True
            break
        w_prev, rate_prev = w, rate
    if crossed:
        # rate and rate_prev lie on opposite sides of the target (or rate
        # equals it), so rate != rate_prev and the interpolation is defined.
        w = w_prev+(target_cn_rate-rate_prev)*(w-w_prev)/(rate-rate_prev)
    else:
        # Exact hit or no bracket found: use the last weight tried instead
        # of interpolating 0/0 into nan.
        w = w_prev
    # Simulate at the final weight so the stored CV and rate describe the
    # same point.
    cv, cn_rate = ssd(mu=w*K, sigma=w*sqrt(K))

    # Step 2: with w fixed, sweep the input rate up and down until the
    # output rate covers [cn_rate_min, cn_rate_max], recording every point.
    all_anf_rate = [target_anf_rate]
    all_cv = [cv]
    all_cn_rate = [cn_rate]

    def sample(anf_rate):
        # Simulate at anf_rate with the fixed weight; record and return the output rate.
        K = num_anf*anf_rate*tau
        point_cv, point_rate = ssd(mu=w*K, sigma=w*sqrt(K))
        all_anf_rate.append(anf_rate)
        all_cv.append(point_cv)
        all_cn_rate.append(point_rate)
        return point_rate

    anf_rate, out_rate = target_anf_rate, cn_rate
    for _ in range(max_sweep_steps):
        if out_rate >= cn_rate_max:
            break
        anf_rate = anf_rate+anf_rate*search_fraction
        out_rate = sample(anf_rate)
    anf_rate, out_rate = target_anf_rate, cn_rate
    for _ in range(max_sweep_steps):
        if out_rate <= cn_rate_min:
            break
        anf_rate = anf_rate-anf_rate*search_fraction
        out_rate = sample(anf_rate)

    all_anf_rate = array(all_anf_rate)
    all_cn_rate = array(all_cn_rate)
    all_cv = array(all_cv)
    order = argsort(all_cn_rate)
    return all_anf_rate[order], all_cn_rate[order], all_cv[order]

# This constructs the image column by column
@mem.cache
def get_image(num_anf_range, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
              search_fraction,
              refractory=0.6*ms, tau=6*ms,
              repeats=50, duration=150*ms, skip=50*ms,
              ):
    """Collect the scattered samples for every number of inputs in ``num_anf_range``.

    Calls ``find_points`` once per entry of ``num_anf_range`` (one image
    column each), printing progress, and concatenates the results.

    Returns
    -------
    num_anf, rate, cv : arrays
        Flat arrays of equal length giving, for every sampled point, the
        number of inputs, the output firing rate and the CV.
    """
    import sys  # only needed for the progress output below

    rate_all = []
    cv_all = []
    num_anf_all = []
    n_columns = len(num_anf_range)
    # sys.stdout.write works on Python 2 and 3, unlike the `print` statement.
    sys.stdout.write('Working on: \n')
    for i, num_anf in enumerate(num_anf_range):
        # 1-based so the last column reads n/n; flush so progress shows
        # while the (slow) search is still running.
        sys.stdout.write('%d/%d ' % (i+1, n_columns))
        sys.stdout.flush()
        _, rate, cv = find_points(num_anf, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
                                  search_fraction, refractory=refractory, tau=tau, repeats=repeats,
                                  duration=duration, skip=skip)
        rate_all.append(rate)
        cv_all.append(cv)
        num_anf_all.append(full(len(cv), num_anf))
    sys.stdout.write('\n')
    rate = hstack(rate_all)
    cv = hstack(cv_all)
    num_anf = hstack(num_anf_all)
    return num_anf, rate, cv

# This does some nice plotting for the image
def plotmap(num_anf_range, cn_rate_min, cn_rate_max, target_cn_rate, target_anf_rate,
            search_fraction,
            refractory=0.6*ms, tau=6*ms,
            repeats=50, duration=150*ms, skip=50*ms,
            log_anf_scale=False, cmap_01=True,
            levels=(0.15, 0.35, 0.8),
            show_interpolated_data=True,
            show_colorbar=True,
            cmap=cm.YlGnBu_r,
            ):
    """Plot the CV over (number of inputs, output firing rate).

    Samples from ``get_image`` are interpolated onto a 100x100 grid in
    normalised coordinates. If ``show_interpolated_data`` is True the grid is
    shown as an image; otherwise the raw samples are drawn as a scatter plot
    on a black background. Labelled white contours at the CV values in
    ``levels`` are always drawn from the interpolated grid.

    Parameters
    ----------
    log_anf_scale : bool
        Normalise the number of inputs logarithmically and use a log x axis.
    cmap_01 : bool
        Fix the colour range to [0, 1]; otherwise use the range of the
        interpolated image.
    levels : sequence of float
        CV values at which contour lines are drawn.
    show_colorbar : bool
        Add a colour bar labelled "CV".
    """
    num_anf_min = amin(num_anf_range)
    num_anf_max = amax(num_anf_range)
    num_anf, rate, cv = get_image(num_anf_range, cn_rate_min, cn_rate_max,
                                  target_cn_rate, target_anf_rate, search_fraction,
                                  refractory=refractory, tau=tau, repeats=repeats,
                                  duration=duration, skip=skip)
    num_anf_orig = num_anf
    rate_orig = rate
    # Interpolate in normalised coordinates so that both axes contribute
    # comparably to the radial-basis distances.
    num_anf = normed(num_anf, logscale=log_anf_scale)
    rate = normed(rate)
    img = interpolate_to_image(num_anf, rate, cv,
                               normed(num_anf_min, num_anf_orig, logscale=log_anf_scale),
                               normed(num_anf_max, num_anf_orig, logscale=log_anf_scale), 100,
                               normed(cn_rate_min, rate_orig), normed(cn_rate_max, rate_orig), 100,
                               )
    extent = (num_anf_min, num_anf_max, float(cn_rate_min), float(cn_rate_max))
    if cmap_01:
        vmin = 0
        vmax = 1
    else:
        vmin = amin(img)
        vmax = amax(img)
    if show_interpolated_data:
        # matplotlib only accepts origin='upper' or 'lower'.
        imshow(img, origin='lower', aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax,
               extent=extent)
    else:
        scatter(num_anf_orig, rate_orig, c=cv, cmap=cmap, vmin=vmin, vmax=vmax, lw=0, s=150)
        ax = gca()
        # set_axis_bgcolor was removed from matplotlib; set_facecolor replaces it.
        set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
        set_background('k')
    if log_anf_scale:
        xscale('log')
    gca().xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%d'))
    if show_colorbar:
        cb = colorbar()
        cb.set_label('CV', rotation=270, labelpad=20)
    xlabel('Number of inputs $N$')
    ylabel('Output firing rate (sp/s)')
    # contour() takes no `aspect` argument (newer matplotlib rejects it).
    cs = contour(img, origin='lower',
                 levels=levels, colors=['w']*len(levels),
                 extent=extent)
    clabel(cs, colors='w', inline=True)

# Callback handed to ipw.interact: redraws the map whenever a control changes.
def deafferentation_plot(tau_ms=6, refractory_ms=0.1,
                         anf_rate_100=200,
                         quality="Normal", show_interpolated_data=True):
    """Draw the deafferentation map for one set of GUI parameter values.

    ``tau_ms`` and ``refractory_ms`` are the membrane time constant and the
    refractory period in ms. ``anf_rate_100`` is the input firing rate (Hz)
    that must produce the reference output rate of 100 Hz; it sets the
    synaptic weight for every column. Output rates from 50 to 300 Hz are
    covered. ``quality == "Normal"`` selects a coarse, fast sweep; any other
    value selects a fine, slow one.
    """
    if quality == "Normal":
        # Coarse: every 5th number of inputs, 10% search steps.
        num_anf_range, search_fraction = arange(1, 51, 5), 0.1
        repeats, duration, skip = 50, 150*ms, 50*ms
    else:
        # Fine: every number of inputs, 1% search steps, more simulated neurons.
        num_anf_range, search_fraction = arange(1, 51, 1), 0.01
        repeats, duration, skip = 4000, 350*ms, 100*ms

    plotmap(num_anf_range, 50*Hz, 300*Hz,
            100*Hz, anf_rate_100*Hz, search_fraction,
            tau=tau_ms*ms, refractory=refractory_ms*ms,
            repeats=repeats, duration=duration, skip=skip,
            show_interpolated_data=show_interpolated_data)

# Interactive controls for deafferentation_plot. `show_interpolated_data`
# has no entry here, so interact() builds a checkbox from its bool default.
widgets = dict(
    tau_ms=ipw.FloatSlider(min=0.1, max=15, step=0.1, value=6,
                continuous_update=False,
                description=r"Membrane time constant $\tau$ (ms)"),
    refractory_ms=ipw.FloatSlider(min=0, max=5, step=0.1, value=0.1,
                continuous_update=False,
                description=r"Refractory period $t_\mathrm{ref}$ (ms)"),
    anf_rate_100=ipw.FloatSlider(min=50, max=500, step=10, value=200,
                continuous_update=False,
                description=r"AN firing rate for 100 sp/s output (sp/s)"),
    quality=ipw.Dropdown(description="Quality",
                options=["Normal", "Publication (very slow)"],
                value="Normal"),
    )
for control in widgets.values():
    # Full-width controls, leaving room for the long descriptions.
    control.layout.width = '100%'
    control.style = {'description_width': '30%'}

# Run the GUI. interact() displays the controls itself and returns the
# function, so wrapping it in display() would also show the function's repr;
# the trailing semicolon suppresses the cell's echo of that return value.
ipw.interact(deafferentation_plot, **widgets);