Note. The following notebook contains code in addition to text and figures. By default, the code is hidden; click the icon that looks like an eye in the toolbar above to show it. To run the code, open the "Cell" menu and choose "Run All".

Behaviour maps

In this notebook, we study the behaviour of the basic and reduced models introduced in the Basic Model notebook.

Each figure has two panels. On the left is a colour map of the coefficient of variation (CV) of the interspike intervals (ISIs), defined as the ratio of the standard deviation of the ISIs to their mean. The CV is 0 for a perfectly regular spike train and 1 for a Poisson spike train. On the right is a colour map of the firing rate.
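
As a minimal sketch of the CV definition (plain NumPy, independent of the notebook code; the helper name `isi_cv` and its `min_isis` cutoff are illustrative assumptions, not part of the notebook), a perfectly regular train gives a CV of essentially 0 and a simulated Poisson train gives a CV close to 1:

```python
import numpy as np

def isi_cv(spike_times, min_isis=2):
    """Hypothetical helper: CV = std(ISI) / mean(ISI), NaN if too few intervals."""
    isi = np.diff(np.sort(np.asarray(spike_times, dtype=float)))
    if len(isi) < min_isis:
        return np.nan  # not enough intervals to estimate a spread
    return isi.std() / isi.mean()

rng = np.random.default_rng(1)
regular = np.arange(0.0, 10.0, 0.01)                   # 100 sp/s, evenly spaced
poisson = np.cumsum(rng.exponential(0.01, size=1000))  # 100 sp/s, exponential ISIs
print(isi_cv(regular), isi_cv(poisson))
```

NumPy's `std` uses the population formula (`ddof=0`), which matches the `std` call in the notebook's own CV computation.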

Both panels contain white contour lines: solid contours for the CV and dashed contours for the firing rate. The hatched grey area in the left-hand CV plot marks the region where too few spikes were produced to compute the CV.
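
The hatching relies on a Matplotlib behaviour: `imshow` masks NaN pixels and, with default colormap settings, draws them transparent, so a hatched rectangle placed underneath (negative `zorder`) shows through wherever the CV is undefined. A self-contained sketch of the same trick on a made-up 2 × 2 array, using the off-screen Agg backend (an assumption so it runs without a display):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.patches as patches

cv = np.array([[0.1, np.nan],   # NaN: CV could not be computed here
               [0.5, 0.9]])
extent = (0.0, 1.0, 0.0, 1.0)

fig, ax = plt.subplots()
im = ax.imshow(cv, origin="lower", aspect="auto", interpolation="nearest",
               cmap="YlGnBu_r", vmin=0, vmax=1, extent=extent)
# Hatched background drawn below the image; visible only through NaN pixels
bg = patches.Rectangle((extent[0], extent[2]), extent[1] - extent[0], extent[3] - extent[2],
                       hatch="xxx", fill=True, fc=(0.7,) * 3, ec=(0.6,) * 3, zorder=-10)
ax.add_patch(bg)
fig.canvas.draw()
```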

The interface below needs a bit of explanation. By default, when you open the page, it shows the first colour map figure in the paper, which plots CV and firing rate as a function of the parameters $\mu$ and $\sigma$ (we use the ratio $\sigma/\mu$ instead of $\sigma$ because it gives a clearer plot). You can also reproduce all the panels of the second colour map figure from the "Model type and options" tab: change "Model type" to "2D Map (basic model)" and choose the variables to vary on the horizontal and vertical axes. In the paper, the vertical axis is always $\mu$, and the horizontal axis is one of the four other variables, excluding the refractory period: $\tau$, $N$, $\alpha$ and $\rho_E$. However, you can plot any variable against any other to see how the model behaves in different slices of the parameter space.

On the "Parameters and results" tab there are various controls. Standard sliders set the parameters that are not shown on the figure, and "range sliders" set the ranges of the horizontal and vertical axis parameters. Use these to explore a larger part of the parameter space or to zoom in on a region. In the current version of the software these sliders can be a bit fiddly and do not always update the plot; if this happens, change another parameter to force an update.
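
Choosing two axis variables and their ranges defines a rectangular grid of parameter values. The plotting code below simulates one neuron per grid point in a single `NeuronGroup`: it flattens the grid into per-neuron parameter arrays, runs the model, then reshapes the results for display. A NumPy sketch of that flatten/reshape step (`grid_arguments` is a hypothetical helper; the notebook uses `meshed_arguments` from `model_explorer_jupyter`, whose internals are assumed here rather than shown):

```python
import numpy as np

def grid_arguments(x_range, y_range, M):
    """Hypothetical helper: turn two (low, high) axis ranges into flat per-neuron arrays."""
    x = np.linspace(x_range[0], x_range[1], M)  # horizontal-axis values
    y = np.linspace(y_range[0], y_range[1], M)  # vertical-axis values
    X, Y = np.meshgrid(x, y)                    # shape (M, M): rows follow y, columns follow x
    return X.shape, X.ravel(), Y.ravel()

# sigma/mu on the horizontal axis, mu on the vertical axis (the default range-slider values)
shape, sigma_mu, mu = grid_arguments((0.0, 1.0), (0.5, 3.0), M=5)
N = mu.size                      # one neuron per grid point: 25 here
cv_flat = sigma_mu * 0.0         # stand-in for the model's per-neuron output
cv_map = cv_flat.reshape(shape)  # row i corresponds to the i-th mu value, ready for imshow(origin='lower')
```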

There is also a "Quality" dropdown box. By default it is set to "Normal", which gives results in a reasonable amount of time but not at very high quality. Choose "Fast" for very quick results, or "Detailed" or "Publication (very slow)" for higher-quality results (the latter takes tens of minutes to run). For longer runs, a progress bar shows how far the simulation has got.
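
The quality levels differ mainly in how many grid points are simulated and for how long. Using the `(M, blur width, skip, duration)` settings from `plot2d` below, each map simulates M × M neurons, so a rough measure of the workload is neurons × simulated seconds. A back-of-the-envelope sketch (actual run times depend on the machine and on Brian's code generation target, and repeated runs are served from the joblib cache):

```python
# (grid points per axis M, blur width, skip [s], duration [s]), copied from plot2d
quality_settings = {
    'Fast': (5, 0, 0.05, 0.1),
    'Normal': (20, 0.1, 0.05, 0.2),
    'Detailed': (40, 0.05, 0.1, 1.0),
    'Publication (very slow)': (50, 0.05, 1.0, 50.0),
}

def neuron_seconds(M, duration):
    """Total simulated time summed over all M*M neurons in the grid."""
    return M * M * duration

costs = {name: neuron_seconds(M, duration)
         for name, (M, _blur, _skip, duration) in quality_settings.items()}
for name, cost in costs.items():
    print(f"{name:>25}: {cost:10.1f} neuron-seconds")
```

By this measure, "Publication (very slow)" is about 1,560 times the "Normal" workload (125,000 versus 80 neuron-seconds).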

Finally, you can turn smoothing on or off to check that it is not producing misleading results, which can happen at the "Fast" quality level.
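
In `plot2d`, smoothing has two steps: a NaN-aware Gaussian blur (`nan_gaussian_filter`, with a width of `blur_width*M` grid cells, e.g. 2 cells at "Normal") followed by bilinear upsampling to roughly 100 × 100 pixels with `scipy.ndimage.zoom`. At "Fast" the blur width is 0, so only the upsampling is applied to a 5 × 5 grid. A standalone sketch on a made-up 6 × 6 map, re-implementing the notebook's filter with explicit NumPy/SciPy imports, shows that the undefined cell stays undefined while an isolated high value is flattened:

```python
import numpy as np
from scipy.ndimage import gaussian_filter, zoom

def nan_gaussian_filter(x, sigma, num_passes):
    # Same logic as the notebook: blur with decreasing widths and keep, at each
    # point, the result of the widest blur that is finite there.
    z = np.full_like(x, np.nan, dtype=float)
    for cursigma in np.linspace(sigma, 0, num_passes + 1)[:-1]:
        y = gaussian_filter(x, cursigma, mode='nearest')
        missing = np.isnan(z)
        z[missing] = y[missing]
    return z

cv = np.full((6, 6), 0.5)
cv[0, 0] = 0.9      # an isolated high value that the blur should flatten
cv[5, 5] = np.nan   # a cell where the CV could not be computed
M = cv.shape[0]

smoothed = nan_gaussian_filter(cv, sigma=1.0, num_passes=10)
upsampled = zoom(smoothed, 100.0 / M, order=1)  # as in plot2d: about 100 x 100 pixels
```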


In [ ]:
###### IMPORT AND UTILITY FUNCTIONS

%matplotlib inline
from brian2 import *
from model_explorer_jupyter import *
import joblib
import ipywidgets as ipw
from collections import OrderedDict
from scipy.ndimage import zoom, gaussian_filter
from matplotlib import cm
import matplotlib.patches as patches
import warnings
warnings.filterwarnings("ignore")
BrianLogger.log_level_error()

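# Used for smoothing the plots: applies a Gaussian filter that handles NaN values properly.
# gaussian_filter spreads a NaN to every output point whose kernel touches it, so we filter
# several times with decreasing widths and, at each pass, keep the new values only where no
# finite value has been found yet. Points far from NaNs get the widest blur, points near
# NaNs get a narrower one, and points that are NaN in x stay NaN.
def nan_gaussian_filter(x, sigma, num_passes):
    z = full_like(x, nan)
    for cursigma in linspace(sigma, 0, num_passes+1)[:-1]:
        y = gaussian_filter(x, cursigma, mode='nearest')
        z[isnan(z)] = y[isnan(z)]
    return z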

progress_slider, update_progress = brian2_progress_reporter()

mem = joblib.Memory(location="joblib", verbose=0)

In [ ]:
#### REDUCED MODEL
## Parameters: mu, sigma, tau, refractory

@mem.cache
def runmodel_raw(N, mu=2.0, sigma=0.1, tau_ms=6, refractory_ms=0.6,
                 repeats=50, duration=150*ms, skip=50*ms):
    tau = tau_ms*ms
    refractory = refractory_ms*ms

    eqs = '''
    dv/dt = (mu-v)/tau+sigma*tau**-0.5*xi : 1 (unless refractory)
    refrac : second
    tau : second
    mu : 1
    sigma : 1
    '''
    
    G = NeuronGroup(N, eqs, threshold='v>1', reset='v=0',
                    refractory='refrac', method='euler')
    G.refrac = refractory
    G.tau = tau
    G.mu = mu
    G.sigma = sigma
    M = SpikeMonitor(G)
    
    run(duration, report=update_progress, report_period=1*second)
    
    CV = zeros(N)
    FR = zeros(N)
    trains = M.spike_trains()
    for i in range(N):
        t = trains[i]
        t = t[t>skip]
        t.sort() # for older Brian versions
        isi = diff(t)
        CV[i] = std(isi)/mean(isi)
        FR[i] = len(t)/(duration-skip)
    return CV, FR

def runmodel_raw_rel(N, mu=2.0, sigma_mu=0.1, tau_ms=6, refractory_ms=0.6,
                     repeats=50, duration=150*ms, skip=50*ms):
    return runmodel_raw(N, mu=mu, sigma=sigma_mu*mu, tau_ms=tau_ms,
                        refractory_ms=refractory_ms, repeats=repeats,
                        duration=duration, skip=skip)
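
To make the reduced model's dynamics explicit without Brian, here is a minimal single-neuron sketch in plain NumPy. It integrates dv/dt = (mu - v)/tau + sigma*tau**-0.5*xi with the Euler-Maruyama scheme, so each step adds a Gaussian increment with standard deviation sigma*sqrt(dt/tau). Assumptions: a 0.1 ms time step (Brian's default clock), v held at the reset value during the refractory period (the effect of `(unless refractory)`), and the hypothetical name `simulate_lif`. It illustrates the equations and is not a replacement for `runmodel_raw`.

```python
import numpy as np

def simulate_lif(mu=2.0, sigma=0.1, tau=6e-3, refractory=0.6e-3,
                 duration=0.15, skip=0.05, dt=1e-4, seed=0):
    """Hypothetical single-neuron version of runmodel_raw (times in seconds).

    Returns (CV, firing rate in sp/s) computed from spikes after `skip`.
    """
    rng = np.random.default_rng(seed)
    n_steps = int(round(duration / dt))
    ref_steps = int(round(refractory / dt))
    noise_sd = sigma * np.sqrt(dt / tau)   # std of the noise increment per step
    v, refractory_left, spikes = 0.0, 0, []
    for step in range(n_steps):
        if refractory_left > 0:            # v is clamped at the reset value
            refractory_left -= 1
            continue
        v += (mu - v) / tau * dt + noise_sd * rng.standard_normal()
        if v > 1.0:                        # threshold 'v>1'
            spikes.append(step * dt)
            v = 0.0                        # reset 'v=0'
            refractory_left = ref_steps
    t = np.array(spikes)
    t = t[t > skip]                        # discard the initial transient
    isi = np.diff(t)
    cv = isi.std() / isi.mean() if len(isi) > 0 else np.nan
    rate = len(t) / (duration - skip)
    return cv, rate

print(simulate_lif(sigma=0.0))   # noise-free: CV essentially 0
print(simulate_lif(sigma=0.5))   # stronger noise: irregular firing
```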

In [ ]:
#### BASIC MODEL
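## Parameters (paper name -> code name): mu -> mu, N -> num_anf, tau -> tau_ms,
## alpha -> inh, t_ref -> refractory_ms, rho_E -> anf_rate_hz

# Note that some of the parameter names differ from the paper
# because this code was written before the final set of names was
# settled on. The extra parameter inh_skew is not exposed in the GUI
# and keeps its default value of 1.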

def runmodel_std(N, mu=2.0, num_anf=50, tau_ms=6,
                 inh=0.0, inh_skew=1.0,
                 refractory_ms=0.6, anf_rate_hz=250,
                 repeats=50, duration=150*ms, skip=50*ms):
    tau = tau_ms*ms
    refractory = refractory_ms*ms
    anf_rate = anf_rate_hz*Hz
    
    tau_exc = tau_inh = tau
    weight = mu/(num_anf*tau*anf_rate*(1-inh))
    weight_exc = weight
    weight_inh = inh_skew*weight
    num_anf_exc = num_anf
    num_anf_inh = num_anf/inh_skew
    anf_rate_exc = anf_rate
    anf_rate_inh = anf_rate*inh
    mu_exc = weight_exc*num_anf_exc*tau_exc*anf_rate_exc
    mu_inh = weight_inh*num_anf_inh*tau_inh*anf_rate_inh
    sigma2_exc = weight_exc*mu_exc
    sigma2_inh = weight_inh*mu_inh
    mu = mu_exc-mu_inh
    sigma = sqrt(sigma2_exc+sigma2_inh)
    return runmodel_raw(N, mu=mu, sigma=sigma, tau_ms=tau_ms,
                        refractory_ms=refractory_ms,
                        repeats=repeats, duration=duration, skip=skip)
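
The function above converts the basic model's parameters into the mean and noise of the reduced model. With the weight $w$ chosen so that the net mean equals $\mu$, the excitatory and inhibitory mean drives are $\mu_E = w N \tau \rho_E$ and $\mu_I = \alpha\, w N \tau \rho_E$ (the `inh_skew` factors cancel), giving $\mu = \mu_E - \mu_I$ and $\sigma^2 = w\mu_E + s\,w\mu_I$ with $s$ = `inh_skew`. A unit-free sketch of the same arithmetic (plain floats, tau in seconds, rates in Hz; `effective_mu_sigma` is a hypothetical name) makes the consequences easy to check:

```python
import math

def effective_mu_sigma(mu=2.0, num_anf=50, tau=6e-3, inh=0.0, inh_skew=1.0,
                       anf_rate=250.0):
    """Hypothetical unit-free restatement of runmodel_std's parameter mapping.

    Returns the (mu, sigma) that runmodel_std passes to runmodel_raw.
    Undefined for inh == 1 (division by zero), as in the notebook.
    """
    weight = mu / (num_anf * tau * anf_rate * (1 - inh))
    weight_inh = inh_skew * weight
    num_anf_inh = num_anf / inh_skew
    mu_exc = weight * num_anf * tau * anf_rate
    mu_inh = weight_inh * num_anf_inh * tau * (anf_rate * inh)
    sigma2 = weight * mu_exc + weight_inh * mu_inh
    return mu_exc - mu_inh, math.sqrt(sigma2)

for alpha in (0.0, 0.5):
    m, s = effective_mu_sigma(inh=alpha)
    print(f"alpha={alpha}: mu={m:.3f}, sigma={s:.3f}, sigma/mu={s / m:.3f}")
```

With the default sliders ($\mu = 2$, $N = 50$, $\tau = 6$ ms, $\rho_E = 250$ sp/s, $\alpha = 0$) this gives $\sigma \approx 0.23$, i.e. $\sigma/\mu \approx 0.115$. Raising $\alpha$ to 0.5 at fixed $\mu$ increases $\sigma$ to about 0.57, because the weight grows by $1/(1-\alpha)$ so that the extra excitation balances the inhibition. The mapping is undefined at $\alpha = 1$, which the inhibitory-fraction slider allows.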

In [ ]:
#### GUI AND PLOTTING CODE

# You don't need to understand most of the code in this section; all of the modelling code
# is in the cells above. You might want to take a look at the plotting code in the
# function that starts "def plotter(**kwds):".

# You're also very welcome to re-use this code to produce your own parameter exploration
# GUIs if you want!

sliders = OrderedDict([
        ('mu', ipw.FloatSlider(description=r"Total mean current $\mu$",
                               min=0.1, max=10, step=0.1, value=2.0)),
        ('tau_ms', ipw.FloatSlider(description=r"Membrane time constant $\tau$ (ms)",
                                   min=0.1, max=50.0, step=0.1, value=6)),
        ('num_anf', ipw.IntSlider(description=r"Number of auditory nerve fibres $N$",
                                  min=1, max=100, step=1, value=50)),
        ('inh', ipw.FloatSlider(description=r"Inhibitory fraction $\alpha$",
                                min=0, max=1, step=0.1, value=0)),
        ('refractory_ms', ipw.FloatSlider(description=r"Refractory period $t_\mathrm{ref}$ (ms)",
                                          min=0, max=5, step=0.1, value=0.6)),
        ('anf_rate_hz', ipw.FloatSlider(description=r"Auditory nerve firing rate $\rho_E$ (sp/s)",
                                        min=10, max=500, step=10, value=250)),
        ('sigma_mu', ipw.FloatSlider(description=r"Total noise relative to mean $\sigma/\mu$",
                                     min=0, max=2, step=0.05, value=0.1)),
        ])
range_sliders = OrderedDict([
        ('mu', ipw.FloatRangeSlider(description=r"Total mean current $\mu$",
                                    min=0.1, max=10, step=0.1, value=(0.5, 3.0))),
        ('tau_ms', ipw.FloatRangeSlider(description=r"Membrane time constant $\tau$ (ms)",
                                        min=0.1, max=50.0, step=0.1, value=(1, 10))),
        ('num_anf', ipw.IntRangeSlider(description=r"Number of auditory nerve fibres $N$",
                                       min=1, max=100, step=1, value=(3, 50))),
        ('inh', ipw.FloatRangeSlider(description=r"Inhibitory fraction $\alpha$",
                                     min=0, max=1, step=0.1, value=(0, 0.8))),
        ('refractory_ms', ipw.FloatRangeSlider(description=r"Refractory period $t_\mathrm{ref}$ (ms)",
                                               min=0, max=5, step=0.1, value=(0.1, 1.0))),
        ('anf_rate_hz', ipw.FloatRangeSlider(description=r"Auditory nerve firing rate $\rho_E$ (sp/s)",
                                             min=10, max=500, step=10, value=(50, 400))),
        ('sigma_mu', ipw.FloatRangeSlider(description=r"Total noise relative to mean $\sigma/\mu$",
                                          min=0, max=2, step=0.05, value=(0, 1))),
        ])
quality_slider = ipw.Dropdown(description="Quality",
                              options=["Fast", "Normal", "Detailed", "Publication (very slow)"],
                              value='Normal')
for slider in list(sliders.values())+list(range_sliders.values()):
    slider.layout.width = '100%'
    slider.style = {'description_width': '30%'}

def savecurfig(fname):
    curfig.savefig(fname)
widget_savefig = save_fig_widget(savecurfig)

vars_std = OrderedDict((k, v.description) for k, v in sliders.items() if k!='sigma_mu')
vars_raw = OrderedDict((k, sliders[k].description) for k in ['mu', 'sigma_mu', 'tau_ms', 'refractory_ms'])

vs2d_std = VariableSelector(vars_std, ['Horizontal axis', 'Vertical axis'], title=None,
                            initial={'Horizontal axis': 'inh', 'Vertical axis': 'mu'})
vs2d_raw = VariableSelector(vars_raw, ['Horizontal axis', 'Vertical axis'], title=None,
                            initial={'Horizontal axis': 'sigma_mu', 'Vertical axis': 'mu'})

options2d_std = {'var': vs2d_std.widgets_vertical}
options2d_raw = {'var': vs2d_raw.widgets_vertical}

def plot2d(modelfunc, vs2d):
    def plotter(**kwds):
        global curfig
        smoothing = kwds.pop('smoothing')
        quality_settings = {'Fast': (5, 0, 50*ms, 100*ms),
                            'Normal': (20, 0.1, 50*ms, 200*ms),
                            'Detailed': (40, 0.05, 100*ms, 1000*ms),
                            'Publication (very slow)': (50, 0.05, 1*second, 50*second)}
        M, blur_width, skip, duration = quality_settings[kwds.pop('quality')]
        # Set up ranges of variables, and generate arguments to pass to model function
        axis_ranges = dict((k, linspace(*(v+(M,)))) for k, v in kwds.items() if k in vs2d.selected)
        array_kwds = meshed_arguments(vs2d.selected, kwds, axis_ranges)
        shape = array_kwds[vs2d.selection['Horizontal axis']].shape
        N = array_kwds[vs2d.selection['Horizontal axis']].size
        array_kwds[vs2d.selection['Horizontal axis']].shape = N
        array_kwds[vs2d.selection['Vertical axis']].shape = N
        # Run the model
        CV, FR = modelfunc(N, skip=skip, duration=duration, **array_kwds)
        # Unflatten the results
        CV.shape = shape
        FR.shape = shape
        # Apply smoothing to improve the appearance if desired
        if smoothing:
            if blur_width:
                FR = nan_gaussian_filter(FR, blur_width*M, 10)
                CV = nan_gaussian_filter(CV, blur_width*M, 10)
            FR = zoom(FR, 100./M, order=1)
            CV = zoom(CV, 100./M, order=1)
        # The rest is plotting code
        extent = (kwds[vs2d.selection['Horizontal axis']]+
                  kwds[vs2d.selection['Vertical axis']])
        def labelit(titletext):
            title(titletext)
            xlabel(sliders[vs2d.selection['Horizontal axis']].description)
            ylabel(sliders[vs2d.selection['Vertical axis']].description)
            cb = colorbar()
            cb.set_label(titletext, rotation=270, labelpad=20)
            # contour() takes no aspect argument; the aspect ratio is set by imshow()
            cs_cv = contour(CV, origin='lower',
                            levels=[0.15, 0.35, 0.8],
                            colors='w',
                            extent=extent,
                            )
            clabel(cs_cv, colors='w', inline=True, fmt='%.2f')
            cs_rate = contour(FR, origin='lower', linestyles='dashed',
                              levels=[50, 100, 250, 400], colors='w',
                              extent=extent,
                              )
            clabel(cs_rate, colors='w', inline=True, fmt='%d sp/s')
        curfig = figure(1, figsize=(12, 5))
        clf()
        subplot(121)
        imshow(CV, origin='lower', aspect='auto',
               interpolation='nearest',
               cmap=cm.YlGnBu_r, vmin=0, vmax=1,
               extent=extent)
        # Hatched background for areas where CV couldn't be computed
        p = patches.Rectangle((extent[0], extent[2]), extent[1]-extent[0], extent[3]-extent[2],
                              hatch='xxx', fill=True, fc=(0.7,)*3, ec=(0.6,)*3, zorder=-10)
        gca().add_patch(p)        
        labelit('CV')
        subplot(122)
        imshow(FR, origin='lower', aspect='auto',
               interpolation='nearest',
               cmap=cm.YlGnBu_r, vmin=0, vmax=450,
               extent=extent)
        labelit('Firing rate (sp/s)')
        tight_layout()
    return plotter
    
def map2d(runmodel, vs2d):
    def f():
        params = vs2d.merge_selected(range_sliders, sliders)
        params['quality'] = quality_slider
        params['smoothing'] = True
        i = ipw.interactive(plot2d(runmodel, vs2d), **params)
        return no_continuous_update(i)
    return f

models = [('2D Map (reduced model)', map2d(runmodel_raw_rel, vs2d_raw),
               options2d_raw, [widget_savefig, progress_slider]),
          ('2D Map (basic model)', map2d(runmodel_std, vs2d_std),
               options2d_std, [widget_savefig, progress_slider]),
         ]

# Create model explorer, and jump immediately to results page
modex = model_explorer(models)
# modex.widget_model_type.value = '2D Map'
modex.tabs.selected_index = 1
display(modex)