Note: the following notebook contains code in addition to text and figures. By default, the code is hidden; click the icon that looks like an eye in the toolbar above to show it. To run the code, open the Cell menu and choose "Run All".

Level dependence

In this notebook, we examine the level dependence of the CV and firing rate of chopper cells. We compare experimental data with the results that can be achieved with the model introduced in the notebooks on the Basic Model and Behaviour Maps.


In [ ]:
%%html
<!-- hack to improve styling of ipywidgets sliders -->
<style type="text/css">
.widget-label {
    min-width: 35ex;
    max-width: 35ex;
}
.widget-hslider {
    width: 100%;
}
.widget-hprogress {
    width: 100%;
}

</style>

In [ ]:
# Imports etc.
%matplotlib inline
from brian2 import *
from model_explorer_jupyter import *
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy.random as np_rand
from functools import partial
from collections import OrderedDict
import ipywidgets as ipw
from scipy import stats
from matplotlib import cm as mplcm
import matplotlib.patheffects as PathEffects

import warnings
warnings.filterwarnings("ignore")
BrianLogger.log_level_error()

import joblib
# On-disk cache (in the directory "joblib") used by the @mem.cache-decorated
# simulation functions below, so repeated runs with identical arguments can
# reuse stored results instead of re-simulating.
# NOTE(review): newer joblib releases renamed the `cachedir` argument to
# `location`; confirm against the installed joblib version.
mem = joblib.Memory(cachedir="joblib", verbose=0)

# Progress widget and its callback: the callback is passed as `report=` to
# Brian's run() in runmodel_raw, and the widget is displayed next to the
# density-plot controls at the end of the notebook.
progress_slider, update_progress = brian2_progress_reporter()

# Integration time step for all Brian simulations in this notebook.
defaultclock.dt = 0.05*ms

# Load experimental data summary stats
# NOTE(review): assumes the four files list the same cells in the same order,
# since they are combined element-wise below — TODO confirm.
exp_cv20 = loadtxt('expdata_summary/allcv20.txt')
exp_cv50 = loadtxt('expdata_summary/allcv50.txt')
exp_fr20 = loadtxt('expdata_summary/allfr20.txt')
exp_fr50 = loadtxt('expdata_summary/allfr50.txt')
# True for cells whose CV crosses the 0.35 sustained/transient boundary
# between the 20 dB and 50 dB levels.
crossing = ((exp_cv20>0.35)&(exp_cv50<0.35))|((exp_cv20<0.35)&(exp_cv50>0.35))
# Per-cell change from 20 dB to 50 dB (overlaid on the density plots below).
exp_frdiff = exp_fr50-exp_fr20
exp_cvdiff = exp_cv50-exp_cv20

# Plotting functions

# Utility function to create axes on the top and right
# We use this below to create histograms on the sides of a plot
def get_sidehist_axes(ax=None):
    """Attach two thin, frameless axes above and to the right of `ax`.

    The top axes shares its x axis with `ax` and the right axes shares its
    y axis. All tick marks and tick labels on both side axes are hidden, so
    they can hold marginal histograms. `ax` (default: the current axes) is
    made the current axes again before returning.

    Returns:
        (ax_top, ax_right)
    """
    main_ax = gca() if ax is None else ax
    divider = make_axes_locatable(main_ax)
    created = []
    for side, shared in (("top", {"sharex": main_ax}),
                         ("right", {"sharey": main_ax})):
        extra_ax = divider.append_axes(side, size=0.5, pad=0.0, **shared)
        # Hide tick labels and tick lines on both axes of the side panel.
        for getter in (extra_ax.get_xticklabels, extra_ax.get_xticklines,
                       extra_ax.get_yticklabels, extra_ax.get_yticklines):
            setp(getter(), visible=False)
        extra_ax.set_frame_on(False)
        created.append(extra_ax)
    sca(main_ax)
    ax_top, ax_right = created
    return ax_top, ax_right

# Plot CV and firing rate, pass varx=fr20, vary=cv20 or fr50, cv50 to
# plot panels A and B
def pointplot(fr20, fr50, cv20, cv50, varx, vary, dolabel=False, col=None, ls=None,
              sidehists=True, muted=False):
    """Scatter `vary` against `varx`, with one marker style per response class.

    Cells are classified by their CVs at both levels (threshold 0.35):
    sustained (below at both; blue triangles), transient (above at both;
    green squares) and mixed (crosses the threshold; red circles). Cells with
    a CV exactly equal to 0.35 at either level fall in no class and are not
    drawn.

    Args:
        fr20, fr50: firing rates at 20/50 dB (not used by this function).
        cv20, cv50: CVs at 20/50 dB, used for the classification.
        varx, vary: arrays to plot, indexed like cv20/cv50.
        dolabel: attach legend labels to the three classes.
        col: colour used for every class instead of the defaults (ignored
            when `muted` is set).
        ls: matplotlib format string used instead of the per-class markers.
        sidehists: add marginal histograms of all of varx and vary.
        muted: use lighter per-class colours.
    """
    # (mask, default marker, default colour, muted colour, legend label)
    classes = (
        ((cv20<0.35)*(cv50<0.35), '^', 'b', (0.5, 0.5, 1), 'Sustained'),
        ((cv20>0.35)*(cv50>0.35), 's', 'g', (0.5, 1, 0.5), 'Transient'),
        (((cv20<0.35)*(cv50>0.35))+((cv20>0.35)*(cv50<0.35)),
         'o', 'r', (1, 0.5, 0.5), 'Mixed'),
    )
    for mask, marker, default_colour, muted_colour, name in classes:
        if muted:
            col = muted_colour
        plot(varx[mask], vary[mask], ls if ls else marker,
             ms=6, mec='none', c=col if col else default_colour,
             label=name if dolabel else None)
    if sidehists:
        ax_top, ax_right = get_sidehist_axes()
        black = (0.0,)*3
        ax_top.hist(varx, 20, facecolor=black, ec='none')
        ax_right.hist(vary, 20, facecolor=black, ec='none', orientation='horizontal')

# Plot panel C
def diffplot(fr20, fr50, cv20, cv50, dolabel=True, muted=False):
    """Panel C: change in CV against change in firing rate (50 dB minus 20 dB).

    Points are classified and coloured as in `pointplot`. Zero lines are
    drawn on both axes, and the view is fixed to +-150 sp/s and +-0.25 CV.
    """
    pointplot(fr20, fr50, cv20, cv50, fr50-fr20, cv50-cv20,
              dolabel=dolabel, muted=muted)
    xlabel('Firing rate difference (sp/s)')
    ylabel('CV difference')
    for zero_line in (axhline, axvline):
        zero_line(0, ls='-', c='k')
    xlim(-150, 150)
    ylim(-0.25, 0.25)

# This and the next function plot panel D
def arrowplot(allfr20, allfr50, allcv20, allcv50, arrowlength=1.0, muted=False):
    """Panel D: one arrow per cell from its 20 dB (rate, CV) towards its 50 dB value.

    Arrow colours use the same classification as `pointplot`: blue if
    CV < 0.35 at both levels, green if CV > 0.35 at both levels, red
    otherwise. `muted` selects lighter shades. `arrowlength` scales each
    arrow; 1.0 ends the arrow exactly at the 50 dB point.

    The view is fixed to 50-500 sp/s and CV 0.1-0.7, with a dashed line at
    the 0.35 sustained/transient boundary.
    """
    for fr20, fr50, cv20, cv50 in zip(allfr20, allfr50, allcv20, allcv50):
        if cv20<0.35 and cv50<0.35:
            c = (0.5, 0.5, 1) if muted else 'b'
        elif cv20>0.35 and cv50>0.35:
            c = (0.5, 1, 0.5) if muted else 'g'
        else:
            c = (1, 0.5, 0.5) if muted else 'r'
        annotate('', xytext=(fr20, cv20),
                 xy=(fr20+arrowlength*(fr50-fr20), cv20+arrowlength*(cv50-cv20)),
                 arrowprops=dict(arrowstyle='->', ec=c))
    axhline(0.35, ls='--', c='k')
    xlabel('Firing rate (sp/s)')
    ylabel('CV')
    # Fixed limits only: data-derived limits would be overwritten here anyway
    # (and amin/amax fail on empty input).
    xlim(50, 500)
    ylim(0.1, 0.7)

def diff_arrow_plot(allfr20, allfr50, allcv20, allcv50, dolabel=True, arrowlength=1.0, muted=False):
    """Two-panel figure: differences (left, subplot 121) and arrows (right, 122).

    Empty side axes are attached to the arrow panel, matching the layout of
    the difference panel, which has marginal histograms. When `dolabel` is
    set, panel letters A and B are added.
    """
    data = (allfr20, allfr50, allcv20, allcv50)
    subplot(121)
    diffplot(*data, dolabel=dolabel, muted=muted)
    subplot(122)
    arrowplot(*data, arrowlength=arrowlength, muted=muted)
    get_sidehist_axes()
    if dolabel:
        for index, letter in enumerate('AB'):
            figtext(0.01+index/2.0, 0.95, letter, size='x-large')

# Plot all panels
def cvfr_level_dependence_plot(allfr20, allfr50, allcv20, allcv50,
                               muted=False, axes=None,
                               highlighted=None):
    """Draw the four-panel CV / firing-rate level-dependence figure.

    Panels: rate vs CV at 20 dB (A) and at 50 dB (B), the 50 - 20 dB
    differences (C), and arrows from the 20 dB to the 50 dB values (D).

    Args:
        allfr20, allfr50, allcv20, allcv50: per-cell firing rates (sp/s) and
            CVs at 20 and 50 dB re threshold.
        muted: draw the cells in lighter colours (used when overlaying a model).
        axes: optional (ax20, ax50, axdiff, axarrow). If omitted, a new
            7x6 inch figure is created, panel letters are added and
            tight_layout() is applied.
        highlighted: optional (fr20, fr50, cv20, cv50) of one cell or model,
            drawn as black stars and a thick black arrow.
    """
    make_figure = axes is None
    if make_figure:
        figure(figsize=(7, 6))
        axes = (subplot(221), subplot(222), subplot(223), subplot(224))
    ax20, ax50, axdiff, axarrow = axes

    has_highlight = highlighted is not None
    if has_highlight:
        hfr20, hfr50, hcv20, hcv50 = highlighted
        stars = ((hfr20, hcv20), (hfr50, hcv50))
    else:
        stars = (None, None)

    # Panels A and B differ only in which level is shown.
    level_panels = ((ax20, allfr20, allcv20, '20 dB re threshold'),
                    (ax50, allfr50, allcv50, '50 dB re threshold'))
    for (panel_ax, rates, cvs, heading), star in zip(level_panels, stars):
        sca(panel_ax)
        pointplot(allfr20, allfr50, allcv20, allcv50, rates, cvs, muted=muted)
        if star is not None:
            plot(star[0], star[1], '*k', ms=12)
        xlabel('Firing rate (sp/s)')
        ylabel('CV')
        title(heading, y=0.87)
        ylim(0.1, 0.7)
        xlim(0, 500)
        axhline(0.35, ls='--', c='k')

    sca(axdiff)
    diffplot(allfr20, allfr50, allcv20, allcv50, muted=muted)
    if has_highlight:
        plot(hfr50-hfr20, hcv50-hcv20, '*k', ms=12, label="Model")
    legend(loc='lower left', fontsize='small')

    sca(axarrow)
    arrowplot(allfr20, allfr50, allcv20, allcv50, muted=muted)
    if has_highlight:
        arrowlength = 1.0
        annotate('', xytext=(hfr20, hcv20),
                 xy=(hfr20+arrowlength*(hfr50-hfr20), hcv20+arrowlength*(hcv50-hcv20)),
                 arrowprops=dict(arrowstyle='->', ec='k', lw=2))
    # Empty side axes keep panel D aligned with panels that have histograms.
    get_sidehist_axes()
    xlabel('Firing rate (sp/s)')
    ylabel('CV')

    if make_figure:
        for col in range(2):
            for row in range(2):
                figtext(0.01+col/2.0, 0.95-row/2.0, 'ABCD'[col+2*row], size='x-large')
        tight_layout()

We start with a plot of the experimental data (a new analysis of data recorded in the lab of Ian Winter over a number of years). It shows that the CV and firing rate of chopper cells change with sound level. The upper two panels show the distributions of these quantities at 20 and 50 dB above threshold. The lower panels show the differences between the two levels, as points (left) and as arrows from the 20 dB value to the 50 dB value (right). Points are blue if the CV is less than 0.35 at both levels (unambiguous sustained choppers), green if the CV is greater than 0.35 at both levels (unambiguous transient choppers), and red if the CV crosses this boundary between the two levels.


In [ ]:
cvfr_level_dependence_plot(exp_fr20, exp_fr50, exp_cv20, exp_cv50)

In [ ]:
# Module-level Brian Network reused by get_steady_state_data so the model is
# only built once; it is rebuilt when the requested number of repeats changes.
thenet = None
def get_steady_state_data(
        repeats=1000,
        duration=250*ms,
        skip=50*ms,
        mu=4.0,
        sigma=0.1,
        tau=5*ms,
        refractory=0*ms,
        ):
    """Simulate `repeats` independent noisy integrate-and-fire neurons.

    Each neuron follows dv/dt = (mu-v)/tau + sigma*tau**-0.5*xi, with
    threshold v>1, reset v=0 and refractory period `refractory`. Spikes during
    the first `skip` are not recorded.

    Returns:
        (cv, rate): the CV of all inter-spike intervals pooled across neurons
        (NaN if there are fewer than two ISIs), and the mean firing rate, i.e.
        the total spike count / (repeats * (duration - skip)).
    """
    global thenet
    if thenet is None or len(thenet['G'])!=repeats:
        eqs = '''
        dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
        refrac : second
        '''
        G = NeuronGroup(repeats, eqs, threshold='v>1', reset='v=0', refractory='refrac',
                        name='G', method='euler')
        M = SpikeMonitor(G, name='M')
        thenet = Network(G, M)
        # Snapshot of the freshly built network (t=0, no spikes recorded).
        thenet.store()
    else:
        G = thenet['G']
        M = thenet['M']
    # Return to the stored snapshot, then set this call's state variables.
    thenet.restore()
    G.refrac = refractory
    G.not_refractory = True
    G.lastspike = -inf*second
    G.v = 0
    # mu, tau and sigma are external constants resolved from this namespace.
    ns = {'mu': mu, 'tau': tau, 'sigma': sigma}
    # Run through the initial transient with the monitor off, then record.
    M.active = False
    thenet.run(skip, namespace=ns)
    M.active = True
    thenet.run(duration-skip, namespace=ns)    
    trains = M.spike_trains()
    # ISIs of every neuron with at least two recorded spikes, pooled together.
    dtrains = [diff(train) for train in trains.values() if len(train)>1]
    if len(dtrains):
        # NOTE(review): multiplying by `second` suggests hstack drops the units
        # here — confirm; only the unitless ratio std/mean is used below.
        isi = hstack(dtrains)*second
    else:
        isi = array([])
    if len(isi)>1:
        cv = std(isi)/mean(isi)
    else:
        cv = nan
    # Monitor was inactive during `skip`, so divide by the recorded time only.
    rate = len(M.t)/(repeats*(duration-skip))
    return cv, rate

In the following interactive figure, you can see the effect of using different excitatory and inhibitory firing rates at 20 and 50 dB for different model parameters. In general, higher input firing rates decrease the CV and increase the output firing rate, whereas more inhibition increases the CV and decreases the firing rate. The model result is shown as a black star or arrow, and the experimental results are shown as above but in a lighter shade.


In [ ]:
def compare_model_to_data(rho20_Hz=150, rho50_Hz=200, N=40, alpha20=0.0, alpha50=0.5,
                          mu=2.0, tau_ms=6.0, refractory_ms=0.1):
    """Simulate one model chopper cell at 20 and 50 dB and overlay it on the data.

    The cell receives N ANF inputs firing at rho20_Hz / rho50_Hz, of which a
    fraction alpha20 / alpha50 is inhibitory. A single synaptic weight is used
    at both levels. It is chosen so that the net mean input, averaged over the
    two levels, equals `mu`. The membrane time constant `tau_ms` and the
    refractory period `refractory_ms` are passed to the simulation.
    The model result is drawn as black stars / arrow on top of the muted
    experimental data.
    """
    # Parameters
    tau = tau_ms*ms
    refractory = refractory_ms*ms
    rho20 = rho20_Hz*Hz
    rho50 = rho50_Hz*Hz
    # Compute synaptic weight (net mean input averaged over both levels = mu)
    weight = mu/(N*tau*0.5*(rho20*(1-alpha20)+rho50*(1-alpha50)))
    # Get model chopper cell results
    def f(weight, N, anf_rate_exc, anf_rate_inh):
        mu_exc = weight*N*tau*anf_rate_exc
        mu_inh = weight*N*tau*anf_rate_inh
        sigma2_exc = weight*mu_exc
        sigma2_inh = weight*mu_inh
        # tau and refractory must reach the simulation too; otherwise the
        # defaults of get_steady_state_data (5 ms, 0 ms) would be used.
        return get_steady_state_data(mu=mu_exc-mu_inh,
                                     sigma=sqrt(sigma2_exc+sigma2_inh),
                                     tau=tau, refractory=refractory)
    cv20, fr20 = f(weight, N, rho20, rho20*alpha20)
    cv50, fr50 = f(weight, N, rho50, rho50*alpha50)
    cvfr_level_dependence_plot(exp_fr20, exp_fr50, exp_cv20, exp_cv50, muted=True,
                               highlighted=[fr20, fr50, cv20, cv50])
 
# One slider per argument of compare_model_to_data; each slider's value is
# passed as the keyword argument of the same name.
# NOTE(review): the original author noted that the OrderedDict ordering "doesn't
# work". ipywidgets appears to lay the controls out in compare_model_to_data's
# parameter order instead — confirm.
# continuous_update=False (ipywidgets option) is intended to rerun the slow
# simulation only when a slider is released, not while dragging.
widgets = OrderedDict([
        ('N', ipw.IntSlider(min=1, max=100, step=1, value=40,
                continuous_update=False,
                description=r"Number of AN fibres $N$")),
        ('mu', ipw.FloatSlider(min=0, max=5, step=0.01, value=2.0,
                continuous_update=False,
                description=r"Mean current $\mu$")),
        ('tau_ms', ipw.FloatSlider(min=0.1, max=15, step=0.1, value=6,
                continuous_update=False,
                description=r"Membrane time constant $\tau$ (ms)")),
        ('refractory_ms', ipw.FloatSlider(min=0, max=5, step=0.1, value=0.1,
                continuous_update=False,
                description=r"Refractory period $t_\mathrm{ref}$ (ms)")),
        ('rho20_Hz', ipw.FloatSlider(min=0, max=500, step=10, value=150,
                continuous_update=False,
                description=r"AN firing rate at 20 dB $\rho_{20}$ (Hz)")),
        ('rho50_Hz', ipw.FloatSlider(min=0, max=500, step=10, value=200,
                continuous_update=False,
                description=r"AN firing rate at 50 dB $\rho_{50}$ (Hz)")),
        ('alpha20', ipw.FloatSlider(min=0, max=1, step=0.01, value=0.0,
                continuous_update=False,
                description=r"Inhibitory fraction at 20 dB $\alpha_{20}$")),
        # NOTE(review): initial value 0.4 differs from compare_model_to_data's
        # default alpha50=0.5; the slider value is the one actually used.
        ('alpha50', ipw.FloatSlider(min=0, max=1, step=0.01, value=0.4,
                continuous_update=False,
                description=r"Inhibitory fraction at 50 dB $\alpha_{50}$")),
    ])

# Build and show the interactive figure; the trailing ';' suppresses
# IPython's echo of the returned value.
display(ipw.interact(compare_model_to_data, **widgets));

In the next figure (which will appear after a minute or so of computation time), you can see a model of the experimental data in the first figure above. To get this model, we have chosen a particular distribution of model parameters, which are the same at 20 and 50 dB levels except for the excitatory and inhibitory firing rates. The excitatory firing rates always increase, whereas the inhibitory rates can increase or decrease from 20 dB to 50 dB.

This figure is not interactive, because it takes a while to compute and because tweaking the parameters of this distribution is not very interesting. If you do want to try out other parameter distributions and see what they look like, show the code (by clicking the eye icon in the toolbar above). Then modify the function parameter_distribution below, and the min_rate and max_rate parameters in the line starting with params near the bottom of the cell. Finally, re-run the cell by clicking it and pressing Ctrl+Enter.


In [ ]:
def sigmoid(x, k):
    """Logistic curve centred on x = 0.5 with steepness k (maps [0, 1] into (0, 1))."""
    z = k*(2*x-1)
    return 1./(1+exp(-z))

def parameter_distribution():
    """Draw one random model parameter set by rejection sampling.

    Returns:
        (mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2, extra), where
        index 1 / 2 refer to the lower / higher sound level and extra is
        dict(tau=..., refractory=...). Accepted draws satisfy
        1 <= mu <= 4, 150 Hz <= ANF rates <= 450 Hz,
        anf_rate_1 <= anf_rate_2 <= anf_rate_1 + 75 Hz and
        -0.1 <= inh_2 - inh_1 <= 0.25.
    """
    k = 6
    inh_max = 0.65
    while True:
        # All random draws happen in a fixed order on every attempt.
        mu = randn()*.4+2
        inh_1 = sigmoid(rand(), k)*inh_max
        inh_2 = sigmoid(rand(), k)*inh_max
        anf_rate_1 = randn()*25*Hz+250*Hz
        anf_rate_2 = randn()*25*Hz+250*Hz
        num_anf = np_rand.randint(30, 60+1)
        # tau and refractory are log-uniform.
        tau = exp(np_rand.uniform(log(5), log(15.0)))*ms
        refractory = exp(np_rand.uniform(log(0.1), log(5.0)))*ms

        mu_ok = 1 <= mu <= 4
        rates_in_range = (150*Hz <= anf_rate_1 <= 450*Hz) and (150*Hz <= anf_rate_2 <= 450*Hz)
        rate_step_ok = anf_rate_1 <= anf_rate_2 <= anf_rate_1+75*Hz
        inh_step_ok = -0.1 <= inh_2-inh_1 <= 0.25
        if mu_ok and rates_in_range and rate_step_ok and inh_step_ok:
            return (mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2,
                    dict(tau=tau, refractory=refractory))

def predict_level_dependence(
          N, dist,
          min_mean_rate=0*Hz, max_mean_rate=inf*Hz,
          min_rate=0*Hz, max_rate=inf*Hz,
          repeats=100, duration=100*ms, skip=15*ms,
          **params):
    """Sample N model cells from `dist` and simulate each one at two sound levels.

    `dist()` must return (mu, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2),
    optionally followed by a dict of extra per-sample parameters. Level 1 is
    the lower level and level 2 the higher one. A single synaptic weight is
    used at both levels. It is calibrated so that the net mean input,
    averaged over the two levels, equals mu.

    Extra parameters come from **params, overridden per sample by the dict
    returned by `dist`:
        tau: membrane time constant (default 6 ms), used for both the weight
            calibration and the simulation;
        inh_skew: inhibitory weight multiplier (default 1.0); the number of
            inhibitory inputs is divided by the same factor;
        any other key (e.g. refractory) is forwarded to get_steady_state_data.

    A sample is discarded and redrawn if either output rate lies outside
    [min_rate, max_rate], or if its mean rate lies outside
    [min_mean_rate, max_mean_rate]. NOTE: this loops forever if no sample
    can ever be accepted.

    Returns:
        (cv_1, cv_2, rate_1, rate_2): arrays of length N.
    """
    def simulate(weight, num_anf, anf_rate_exc, anf_rate_inh, tau, inh_skew, sim_params):
        # Excitatory and inhibitory inputs share the membrane time constant.
        weight_exc = weight
        weight_inh = inh_skew*weight
        num_anf_exc = num_anf
        num_anf_inh = num_anf/inh_skew
        mu_exc = weight_exc*num_anf_exc*tau*anf_rate_exc
        mu_inh = weight_inh*num_anf_inh*tau*anf_rate_inh
        sigma2_exc = weight_exc*mu_exc
        sigma2_inh = weight_inh*mu_inh
        kwargs = dict(sim_params)
        kwargs.update(repeats=repeats, duration=duration, skip=skip, tau=tau,
                      mu=mu_exc-mu_inh, sigma=sqrt(sigma2_exc+sigma2_inh))
        return get_steady_state_data(**kwargs)

    all_cv_1 = []
    all_cv_2 = []
    all_cn_rate_1 = []
    all_cn_rate_2 = []
    while len(all_cv_1)<N:
        p = dist()
        # Build the simulation parameters per sample. Binding them once before
        # the loop (e.g. with functools.partial) would freeze a copy, and the
        # tau/refractory drawn by dist() would never reach the simulation.
        sample_params = dict(params)
        if len(p)==7:
            sample_params.update(p[6])
        tau = sample_params.pop('tau', 6*ms)
        # float() avoids Python 2 integer division in num_anf/inh_skew.
        inh_skew = float(sample_params.pop('inh_skew', 1.0))
        mu_base, num_anf, anf_rate_1, anf_rate_2, inh_1, inh_2 = p[:6]
        weight = mu_base/(num_anf*tau*0.5*(anf_rate_1*(1-inh_1)+anf_rate_2*(1-inh_2)))
        cv_1, cn_rate_1 = simulate(weight, num_anf, anf_rate_1, anf_rate_1*inh_1,
                                   tau, inh_skew, sample_params)
        cv_2, cn_rate_2 = simulate(weight, num_anf, anf_rate_2, anf_rate_2*inh_2,
                                   tau, inh_skew, sample_params)
        mean_rate = 0.5*(cn_rate_1+cn_rate_2)
        if mean_rate<min_mean_rate or mean_rate>max_mean_rate:
            continue
        if cn_rate_1<min_rate or cn_rate_2<min_rate or cn_rate_1>max_rate or cn_rate_2>max_rate:
            continue
        all_cv_1.append(cv_1)
        all_cv_2.append(cv_2)
        all_cn_rate_1.append(cn_rate_1)
        all_cn_rate_2.append(cn_rate_2)
    return array(all_cv_1), array(all_cv_2), array(all_cn_rate_1), array(all_cn_rate_2)

# One group of model cells: N=86 samples from parameter_distribution, keeping
# only samples whose output rate stays between 100 and 450 sp/s at both levels.
# More dicts can be appended; their results are concatenated below.
params = [dict(N=86, dist=parameter_distribution, min_rate=100*Hz, max_rate=450*Hz)]
# Transpose the per-group (cv_1, cv_2, rate_1, rate_2) results and stack each
# quantity across groups; level 1 is plotted as 20 dB and level 2 as 50 dB.
cv1, cv2, rate1, rate2 = map(hstack, zip(*[predict_level_dependence(**p) for p in params]))
cvfr_level_dependence_plot(rate1, rate2, cv1, cv2)

The next interactive figure is a density plot of the behaviour shown above. Here, for each parameter of the model, you can specify a uniform distribution by setting its minimum and maximum values. We then generate a few hundred or thousand parameter sets from these distributions. We reject those with a CV outside the range (0.1, 0.7) or a firing rate outside (100, 500) spikes per second at either level, as well as those whose inhibitory fraction at 50 dB falls outside (0, 1). If 10% or fewer of the samples are accepted, the plot is greyed out. Finally, we plot the firing rate and CV at 20 dB and 50 dB, and the changes in firing rate and CV when moving from 20 to 50 dB, as density plots (darker colours = higher density). The experimental data are plotted on top of these density plots as black dots.

The density plots are generated by kernel density estimation, using SciPy's built-in implementation (scipy.stats.gaussian_kde).


In [ ]:
#### REDUCED MODEL
## Parameters: mu, sigma, tau, refractory

# Same code as in basic notebook

@mem.cache
def runmodel_raw(N, mu=2.0, sigma=0.1, tau_ms=6, refractory_ms=0.6,
                 repeats=50, duration=150*ms, skip=50*ms):
    """Simulate N independent noisy integrate-and-fire neurons; return per-neuron (CV, rate).

    Neuron i follows dv/dt = (mu-v)/tau + sigma*tau**-0.5*xi, with threshold
    v>1, reset v=0 and refractory period refractory_ms. mu, sigma, tau_ms and
    refractory_ms may be scalars or length-N arrays (they are stored as
    per-neuron variables). Only spikes after `skip` are analysed.

    Returns:
        CV: length-N array, ISI coefficient of variation of each neuron
            (NaN if it fired fewer than two spikes after `skip`).
        FR: length-N array, spike count after `skip` / (duration - skip),
            stored as a plain float (spikes per second).

    `repeats` is accepted but not used. Results are cached on disk by joblib
    via `mem`.
    """
    tau = tau_ms*ms
    refractory = refractory_ms*ms

    # Per-neuron parameters, so each neuron can have its own mu/sigma/tau.
    eqs = '''
    dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
    refrac : second
    tau : second
    mu : 1
    sigma : 1
    '''
    
    G = NeuronGroup(N, eqs, threshold='v>1', reset='v=0',
                    refractory='refrac', method='Euler')
    G.refrac = refractory
    G.tau = tau
    G.mu = mu
    G.sigma = sigma
    M = SpikeMonitor(G)
    
    # Brian's magic run() simulates the objects created in this scope (G, M).
    run(duration, report=update_progress, report_period=1*second)
    
    CV = zeros(N)
    FR = zeros(N)
    trains = M.spike_trains()
    for i in xrange(N):
        t = trains[i]
        t = t[t>skip]
        t.sort() # for older Brian versions
        isi = diff(t)
        CV[i] = std(isi)/mean(isi)
        FR[i] = len(t)/(duration-skip)
    return CV, FR

def runmodel_raw_rel(N, mu=2.0, sigma_mu=0.1, tau_ms=6, refractory_ms=0.6,
                     repeats=50, duration=150*ms, skip=50*ms):
    """Same as runmodel_raw, but the noise level is given relative to the mean.

    `sigma_mu` is the ratio sigma/mu; the absolute sigma passed on is
    sigma_mu*mu. All other arguments are forwarded unchanged.
    """
    absolute_sigma = sigma_mu*mu
    forwarded = dict(mu=mu, sigma=absolute_sigma, tau_ms=tau_ms,
                     refractory_ms=refractory_ms, repeats=repeats,
                     duration=duration, skip=skip)
    return runmodel_raw(N, **forwarded)

In [ ]:
#### BASIC MODEL
## Parameters mu, N, tau, alpha, t_ref, rho_E

# Note that some of the parameter names are different to the paper
# because this code was written before the final set of names was
# settled on.

# This is the same code as in the first notebook, except that the
# synaptic weight is calibrated using the _base ANF rate and
# fraction of inhibition, allowing us to have the same synaptic
# weight at 20 and 50 dB levels, even though the ANF rate and
# inhibitory fraction may be different.

def runmodel_std_base_weight(
                 N, mu=2.0, num_anf=50, tau_ms=6,
                 inh=0.0, inh_skew=1.0,
                 refractory_ms=0.6, anf_rate_hz=250,
                 anf_rate_base_hz=250, inh_base=0.0,
                 repeats=50, duration=150*ms, skip=50*ms):
    """Simulate N model cells driven by excitatory and inhibitory ANF inputs.

    The synaptic weight is calibrated so that the net mean input equals `mu`
    at the base ANF rate `anf_rate_base_hz` with base inhibitory fraction
    `inh_base`. The cells are then driven at `anf_rate_hz` with inhibitory
    fraction `inh`, so the same weight can be reused across sound levels.
    `inh_skew` multiplies the inhibitory weight and divides the number of
    inhibitory inputs by the same factor: the mean input is unchanged but its
    variance is not. Each input population contributes weight*mean to the
    variance. Returns the (CV, rate) arrays from runmodel_raw.
    """
    tau = tau_ms*ms
    rate = anf_rate_hz*Hz
    base_rate = anf_rate_base_hz*Hz

    # Calibrate the weight at the base condition.
    w = mu/(num_anf*tau*base_rate*(1-inh_base))
    w_inh = inh_skew*w
    n_inh = num_anf/inh_skew

    # Mean drive from each population at the actual condition.
    drive_exc = w*num_anf*tau*rate
    drive_inh = w_inh*n_inh*tau*(rate*inh)

    net_mu = drive_exc-drive_inh
    net_sigma = sqrt(w*drive_exc+w_inh*drive_inh)
    return runmodel_raw(N, mu=net_mu, sigma=net_sigma, tau_ms=tau_ms,
                        refractory_ms=refractory_ms,
                        repeats=repeats, duration=duration, skip=skip)

In [ ]:
# Take a series of x, y points and plot a density map using kernel density estimation
# N is the grid size for the density image

def density_map(x, y, N, xmin=None, xmax=None, ymin=None, ymax=None):
    """Evaluate a Gaussian kernel density estimate of points (x, y) on an N x N grid.

    Bounds that are not given default to the data range. Returns
    (density, extent): density[j, i] is the estimate at the i-th x grid value
    and the j-th y grid value (rows follow y, ready for imshow with origin
    'lower'), and extent is (xmin, xmax, ymin, ymax).
    """
    # Fall back to the data range for any missing bound.
    xmin = amin(x) if xmin is None else xmin
    xmax = amax(x) if xmax is None else xmax
    ymin = amin(y) if ymin is None else ymin
    ymax = amax(y) if ymax is None else ymax
    # A complex step in mgrid means "this many points, end points included".
    grid_x, grid_y = mgrid[xmin:xmax:N*1j, ymin:ymax:N*1j]
    grid_points = vstack([grid_x.ravel(), grid_y.ravel()])
    # Perform the kernel density estimate.
    kde = stats.gaussian_kde(vstack([x, y]))
    density = kde(grid_points).reshape(grid_x.shape)
    return density.T, (xmin, xmax, ymin, ymax)

def plot_density_map(x, y, N, xmin=None, xmax=None, ymin=None, ymax=None, cmap=mplcm.afmhot_r, **args):
    """Draw the kernel density estimate of points (x, y) as an image on the current axes.

    The density is computed by density_map on an N x N grid, using the given
    bounds or the data range. The colour scale runs from 0 to
    max(density)/0.7, so the densest region maps to 70% of the colormap
    rather than its extreme. Extra keyword arguments go to imshow.
    """
    img, extent = density_map(x, y, N, xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
    # origin must be 'upper' or 'lower'. 'lower' puts row 0 (ymin) at the
    # bottom; the former 'lower left' is not a valid value and newer
    # matplotlib versions reject it.
    imshow(img, origin='lower', aspect='auto', interpolation='nearest',
           extent=extent, cmap=cmap,
           vmin=0, vmax=amax(img)/0.7,
           **args
           )

In [ ]:
# This is the model used for the density plot figures, samples N random values from the
# distributions given, and then computes their cv/fr at 20/50 dB

@mem.cache
def compute_sample_model_from_param_distributions(
            N,
            mu_dist, num_anf_dist, tau_dist, inh_dist, refractory_dist, anf_rate_dist,
            delta_anf_rate_dist, delta_inh_dist,
            duration=1050*ms, skip=50*ms):
    """Draw N parameter sets and simulate each model cell at 20 and 50 dB.

    Each *_dist argument must provide rvs(N). The 50 dB condition uses
    inh_20 + delta_inh and anf_rate_20 + delta_anf_rate. Both runs use the
    synaptic weight calibrated at the 20 dB condition (the *_base arguments
    of runmodel_std_base_weight). Units follow runmodel_std_base_weight:
    tau and refractory in ms, ANF rates in Hz.

    The function is memoised on disk by joblib (@mem.cache), so a repeated
    call with equal arguments is expected to return the cached sample rather
    than draw a new one.

    Returns:
        (cv_20, fr_20, cv_50, fr_50, inh_50, delta_inh, delta_anf_rate)
    """
    # Generate parameters from the distributions
    # (the order of the rvs() calls determines which random numbers each
    # parameter receives)
    mu = mu_dist.rvs(N)
    num_anf = num_anf_dist.rvs(N)
    tau = tau_dist.rvs(N)
    refractory = refractory_dist.rvs(N)
    inh_20 = inh_dist.rvs(N)
    delta_inh = delta_inh_dist.rvs(N)
    inh_50 = inh_20+delta_inh
    anf_rate_20 = anf_rate_dist.rvs(N)
    delta_anf_rate = delta_anf_rate_dist.rvs(N)
    anf_rate_50 = anf_rate_20+delta_anf_rate
    # run the model
    cv_20, fr_20 = runmodel_std_base_weight(N, mu=mu, num_anf=num_anf, tau_ms=tau, inh=inh_20,
                                refractory_ms=refractory, anf_rate_hz=anf_rate_20,
                                inh_base=inh_20, anf_rate_base_hz=anf_rate_20,
                                duration=duration, skip=skip)
    # Same weight calibration (20 dB base) but driven with the 50 dB inputs.
    cv_50, fr_50 = runmodel_std_base_weight(N, mu=mu, num_anf=num_anf, tau_ms=tau, inh=inh_50,
                                refractory_ms=refractory, anf_rate_hz=anf_rate_50,
                                inh_base=inh_20, anf_rate_base_hz=anf_rate_20,
                                duration=duration, skip=skip)
    return cv_20, fr_20, cv_50, fr_50, inh_50, delta_inh, delta_anf_rate

# Shorthand
class uniform(object):
    """Uniform distribution on [low, high] exposing only ``rvs(N)``.

    A shorthand for scipy.stats.uniform, which is parameterised by
    loc/scale rather than by its end points.
    """
    def __init__(self, low, high):
        self.low = low
        self.high = high

    def rvs(self, N):
        """Return N samples drawn from the global NumPy random state."""
        width = self.high-self.low
        frozen = stats.uniform(loc=self.low, scale=width)
        return frozen.rvs(N)

In [ ]:
# The *_dist arguments can be frozen scipy.stats distributions or any other
# object with an rvs(N) method returning N samples (e.g. `uniform` above).
def sample_model_from_param_distributions(
            N,
            mu_dist, num_anf_dist, tau_dist, inh_dist, refractory_dist, anf_rate_dist,
            delta_anf_rate_dist, delta_inh_dist,
            cv_min=0.1, cv_max=0.7, fr_min=100, fr_max=500,
            duration=1050*ms, skip=50*ms,
            show_absolute_plots=True,
            labelling=True, showexpdata=True, lw=2, arrowsize=15, mapdim=200, showarrows=True,
            classification_20='all'):
    """Sample N model cells from parameter distributions and plot density maps.

    The samples come from compute_sample_model_from_param_distributions
    (cached on disk). A sample is kept only if its inhibitory fraction at
    50 dB lies in (0, 1), and its CV and firing rate lie strictly inside
    (cv_min, cv_max) and (fr_min, fr_max) at both levels. With
    classification_20='sustained' / 'transient', only cells with a 20 dB CV
    below / above 0.35 are kept.

    If more than 10% of samples are accepted, the following are drawn:
    rate/CV densities at 20 and 50 dB on subplots 221 and 222 (when
    show_absolute_plots is true), and the density of the 50 - 20 dB
    differences on the current axes (subplot 223 in that case). Arrows
    (showarrows) point to the mean difference of samples near a pure
    increase in excitation, or a pure decrease or increase in inhibition.
    Experimental data (showexpdata) are overlaid as black dots. Otherwise the
    current axes is filled with grey.

    The module-level flag PRINT_ACCEPTED controls printing of the acceptance rate.
    """
    cv_20, fr_20, cv_50, fr_50, inh_50, delta_inh, delta_anf_rate = compute_sample_model_from_param_distributions(
            N,
            mu_dist, num_anf_dist, tau_dist, inh_dist, refractory_dist, anf_rate_dist,
            delta_anf_rate_dist, delta_inh_dist,
            duration=duration, skip=skip)
    # Acceptance mask (products of boolean arrays act as logical AND).
    I = (inh_50<1)*(inh_50>0)
    I = I*(cv_20>cv_min)*(cv_20<cv_max)*(fr_20>fr_min)*(fr_20<fr_max)
    I = I*(cv_50>cv_min)*(cv_50<cv_max)*(fr_50>fr_min)*(fr_50<fr_max)
    if classification_20=='sustained':
        I = I*(cv_20<0.35)
    if classification_20=='transient':
        I = I*(cv_20>0.35)
    if PRINT_ACCEPTED:
        # Function-call form prints the same text under Python 2 and 3.
        print('Accepted %d points from %d generated, rate=%.2f' % (sum(I), N, sum(I)*1.0/N))
    if sum(I)*1.0/N>0.1:
        cv_20 = cv_20[I]
        cv_50 = cv_50[I]
        fr_20 = fr_20[I]
        fr_50 = fr_50[I]
        cv_diff = cv_50-cv_20
        fr_diff = fr_50-fr_20
        if show_absolute_plots:
            subplot(221)
            plot_density_map(fr_20, cv_20, mapdim, xmin=50, xmax=500, ymin=0.1, ymax=0.7)
            if showexpdata:
                plot(exp_fr20, exp_cv20, '.k', ms=4)
            if labelling:
                title('20 dB re threshold')
                xlabel('Firing rate (sp/s)')
                ylabel('CV')
            axhline(0.35, ls='--', c='k')
            subplot(222)
            plot_density_map(fr_50, cv_50, mapdim, xmin=50, xmax=500, ymin=0.1, ymax=0.7)
            if labelling:
                title('50 dB re threshold')
                xlabel('Firing rate (sp/s)')
                ylabel('CV')
            axhline(0.35, ls='--', c='k')
            if showexpdata:
                plot(exp_fr50, exp_cv50, '.k', ms=4)
            subplot(223)
        plot_density_map(fr_diff, cv_diff, mapdim, xmin=-150, xmax=150, ymin=-0.25, ymax=0.25)
        axhline(0, c='k')
        axvline(0, c='k')
        delta_inh = delta_inh[I]
        delta_anf_rate = delta_anf_rate[I]
        def matching(delta_inh_mid, delta_anf_rate_mid, tolerance_inh=0.1, tolerance_anf_rate=40):
            # Mean (rate, CV) change of the accepted samples whose parameter
            # changes lie within the tolerances of the requested values.
            J = delta_inh>delta_inh_mid-tolerance_inh
            J *= delta_inh<delta_inh_mid+tolerance_inh
            J *= delta_anf_rate>delta_anf_rate_mid-tolerance_anf_rate
            J *= delta_anf_rate<delta_anf_rate_mid+tolerance_anf_rate
            return mean(fr_diff[J]), mean(cv_diff[J])
        if showarrows:
            # Arrow targets use half of the most extreme sampled change.
            max_delta_inh = amax(delta_inh)*0.5
            min_delta_inh = amin(delta_inh)*0.5
            max_delta_anf_rate = amax(delta_anf_rate)*0.5
            # Each arrow is best effort: one that cannot be drawn is skipped.
            try:
                annotate('', xytext=(0, 0), xy=matching(0, max_delta_anf_rate), annotation_clip=False,
                         arrowprops=dict(arrowstyle="-|>", fc='k', ec='k', ls='--', lw=lw), size=arrowsize)
            except ValueError:
                pass
            try:
                if min_delta_inh<0:
                    annotate('', xytext=(0, 0), xy=matching(min_delta_inh, 0), annotation_clip=False,
                             arrowprops=dict(arrowstyle="-|>", fc='k', ec='k', lw=lw), size=arrowsize)
            except ValueError:
                pass
            try:
                if max_delta_inh>0:
                    annotate('', xytext=(0, 0), xy=matching(max_delta_inh, 0), annotation_clip=False,
                             arrowprops=dict(arrowstyle="-|>", fc='k', ec='k', lw=lw), size=arrowsize)
            except ValueError:
                pass
        if labelling:
            if showarrows:
                x, y = matching(0, max_delta_anf_rate)
                try:
                    text(x, y, 'More excitation', va='center', ha='left',
                         path_effects=[PathEffects.withStroke(linewidth=2, foreground='w')])
                except ValueError:
                    pass
                try:
                    if min_delta_inh<0:
                        text(*(matching(min_delta_inh, 0)+('Less inhibition',)), verticalalignment='top', ha='center',
                             path_effects=[PathEffects.withStroke(linewidth=2, foreground='w')])
                except ValueError:
                    pass
                try:
                    if max_delta_inh>0:
                        text(*(matching(max_delta_inh, 0)+('More inhibition',)), verticalalignment='bottom', ha='center',
                             path_effects=[PathEffects.withStroke(linewidth=2, foreground='w')])
                except ValueError:
                    pass
            xlabel('Firing rate difference (sp/s)')
            ylabel('CV difference')
        else:
            xticks([])
            yticks([])
        if showexpdata:
            if classification_20=='sustained':
                J = exp_cv20<0.35
            elif classification_20=='transient':
                J = exp_cv20>0.35
            else:
                J = slice(None)
            plot(exp_frdiff[J], exp_cvdiff[J], '.k', ms=4)
            xlim(-150, 150)
            ylim(-0.25, 0.25)
    else:
        # Too few accepted samples: grey out the axes instead of plotting.
        ax = gca()
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor((0.7, 0.7, 0.7))
        else:
            # Very old matplotlib only has the (since removed) older name.
            ax.set_axis_bgcolor((0.7, 0.7, 0.7))

# When True, sample_model_from_param_distributions prints how many generated
# samples passed its acceptance criteria.
PRINT_ACCEPTED = True

# Set to 1 to generate Figure 7 from paper, set to 0 for online interactive version
# (draws only the differences panel, show_absolute_plots=False, and saves it
# to level_dependence_density.pdf)
if 0:
    N = 10000

    # Uniform parameter ranges; units as consumed by runmodel_std_base_weight
    # (tau and refractory in ms, ANF rates and their change in Hz).
    mu_dist = uniform(1, 5)
    num_anf_dist = uniform(3, 100)
    tau_dist = uniform(1, 10)
    inh_dist = uniform(0, 0.8)
    refractory_dist = uniform(0.0, 1.0)
    anf_rate_dist = uniform(100, 300)
    delta_anf_rate_dist = uniform(0, 160)
    delta_inh_dist = uniform(-0.3, 0.5)

    figure(figsize=(5, 4))
    sample_model_from_param_distributions(N, mu_dist=mu_dist, num_anf_dist=num_anf_dist, tau_dist=tau_dist,
                                          inh_dist=inh_dist, refractory_dist=refractory_dist,
                                          anf_rate_dist=anf_rate_dist, delta_anf_rate_dist=delta_anf_rate_dist,
                                          delta_inh_dist=delta_inh_dist,
                                          show_absolute_plots=False,)

    tight_layout()
    savefig('level_dependence_density.pdf')

In [ ]:
# This code just sets up an interactive version of the figure above

def interactive_density(N, mu, num_anf, tau, inh, refractory, anf_rate, delta_anf_rate, delta_inh):
    """Callback for the density-plot widgets.

    Every argument except N is a (low, high) pair from a range slider. It
    becomes a uniform distribution for the corresponding model parameter.
    N samples are drawn and plotted, with the absolute panels, in a new
    10x8 inch figure.
    """
    slider_ranges = dict(mu_dist=mu, num_anf_dist=num_anf, tau_dist=tau,
                         inh_dist=inh, refractory_dist=refractory,
                         anf_rate_dist=anf_rate,
                         delta_anf_rate_dist=delta_anf_rate,
                         delta_inh_dist=delta_inh)
    distributions = dict((name, uniform(*bounds))
                         for name, bounds in slider_ranges.items())
    figure(figsize=(10, 8))
    sample_model_from_param_distributions(N, show_absolute_plots=True, **distributions)
    tight_layout()
    
# Controls for interactive_density: each FloatRangeSlider yields a (min, max)
# pair used as the bounds of a uniform distribution for that parameter.
# NOTE(review): no_continuous_update comes from model_explorer_jupyter and
# presumably makes the sliders trigger only on release — confirm.
interactive_widgets = no_continuous_update(ipw.interactive(interactive_density,
    N=ipw.IntSlider(description=r'Number of points to generate',
                    min=100, max=10000, step=100, value=1000),
    mu=ipw.FloatRangeSlider(description=r'Mean current at 20 dB $\mu$',
                            min=0.0, max=10, step=0.1, value=(1, 5)),
    num_anf=ipw.FloatRangeSlider(description=r'Number of ANF inputs $N$',
                                 min=1, max=150, step=1, value=(3, 100)),
    tau=ipw.FloatRangeSlider(description=r'Time constant $\tau$ (ms)',
                             min=0.1, max=30, step=0.1, value=(1, 10)),
    inh=ipw.FloatRangeSlider(description=r'Inhibitory fraction at 20 dB $\alpha$',
                             min=0, max=1, step=0.05, value=(0, 0.8)),
    refractory=ipw.FloatRangeSlider(description=r'Refractoriness $t_\mathrm{ref}$ (ms)',
                                    min=0, max=5, step=0.1, value=(0, 1)),
    anf_rate=ipw.FloatRangeSlider(description=r'ANF firing rate at 20 dB $\rho$ (sp/s)',
                                  min=10, max=500, step=10, value=(100, 300)),
    delta_anf_rate=ipw.FloatRangeSlider(description=r'Increase in ANF firing rate at 50 dB $\Delta\rho$ (sp/s)',
                                        min=0, max=500, step=10, value=(0, 160)),
    delta_inh=ipw.FloatRangeSlider(description=r'Change in inhibitory fraction at 50 dB $\Delta\alpha$',
                                   min=-1, max=1, step=0.05, value=(-0.3, 0.5)),
    ))
# Show the controls together with the simulation progress widget.
display(ipw.VBox(children=[interactive_widgets, progress_slider]))