Finding simulation parameters that best match experimental spiking profiles


In [ ]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')

# render figure text through LaTeX (requires a working LaTeX installation),
# using the serif font 'cm' (presumably Computer Modern)
plt.rc('text', usetex=True)
plt.rc('font', family='serif', serif='cm')

# small font sizes for publication-style figures
plt.rcParams['figure.titlesize'] = 10
plt.rcParams['axes.labelsize'] = 8
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['axes.labelpad'] = 3.0

from IPython.display import display, clear_output
from ipywidgets import FloatProgress

# comment out the next line if not working on a retina-display computer
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in newer IPython
# releases (matplotlib_inline.backend_inline.set_matplotlib_formats replaces it)
import IPython
IPython.display.set_matplotlib_formats('retina')

In [ ]:
# automatically reload edited project modules (simulation, basic_defs, helpers)
# before each cell is executed
%load_ext autoreload
%autoreload 2

In [ ]:
import numpy as np
import scipy.io
import scipy.optimize as optimize
import scipy.stats as stats
import copy
import time
import os
import cPickle as pickle

import cma

In [ ]:
import simulation
from basic_defs import *
from helpers import *

General definitions


In [ ]:
def lognormal_to_meanstdev(mu, sigma):
    """ Find lognormal mean and standard deviation given mu and sigma
    parameters.
    
    Parameters
    ----------
      mu, sigma
          Mean and standard deviation of associated normal distribution.
          Scalars or numpy arrays (the formulas are applied element-wise).
    
    Returns
    -------
      (mean, stdev)
          Mean and standard deviation of lognormal distribution:
              mean  = exp(mu + sigma**2/2)
              stdev = sqrt((exp(sigma**2) - 1)*exp(2*mu + sigma**2))
    """
    sigma_sq = sigma**2
    # divide by 2.0, not 2: under Python 2 an integer `sigma` would otherwise
    # turn sigma**2/2 into a floor division (e.g. 1/2 == 0)
    mean = np.exp(mu + sigma_sq/2.0)
    stdev = np.sqrt((np.exp(sigma_sq) - 1)*np.exp(2*mu + sigma_sq))
    return (mean, stdev)

In [ ]:
def find_bursts(isi, percentile_threshold_in=65, percentile_threshold_out=80, min_spikes=3,
                isi_in=1.0/80.0, isi_out=1.0/40.0, return_idxs=False):
    """ Find the bursts in a vector of ISIs.
    
    Bursts are returned as a list of tuples, (start, end), giving times of burst start
    and end using the first spike in the recording as a reference point.
    
    A burst starts at an ISI shorter than the entry threshold and ends at the first later
    ISI that is at least as long as the exit threshold. A burst that is still running when
    the ISIs run out ends at the last spike.
    
    Change `percentile_threshold_in` and `percentile_threshold_out` to set the threshold for
    entering a burst and exiting one, respectively. Low percentile means low ISI threshold.
    The values are in percent.
    
    If `isi_in` and `isi_out` are not `None`, they are used instead.
    
    A burst will only be recorded if it has at least `min_spikes` spikes.
    
    If `return_idxs` is `True`, the function also returns the pairs of spike indices where
    bursts start and end. Spike k is the spike preceding ISI k, so the last spike has index
    `len(isi)`. The list of times and the list of indices always have the same length.
    """
    isi = np.asarray(isi)
    
    if isi.size == 0:
        # no ISIs means no bursts (and no percentiles to compute)
        return ([], []) if return_idxs else []
    
    if isi_in is None:
        isi_threshold_in = np.percentile(isi, percentile_threshold_in)
    else:
        isi_threshold_in = isi_in
    
    if isi_out is None:
        isi_threshold_out = np.percentile(isi, percentile_threshold_out)
    else:
        isi_threshold_out = isi_out
    
    # ISI indices where bursts may start...
    starts = (isi < isi_threshold_in).nonzero()[0]
    # ...and where they may end
    ends = (isi >= isi_threshold_out).nonzero()[0]
    
    res = []    # (start time, end time) pairs
    res0 = []   # (first spike index, last spike index) pairs
    
    if len(starts) > 0:
        # spike times relative to the first spike: spike k occurs at times[k]
        times = np.hstack(([0], np.cumsum(isi)))
        n_isi = len(isi)
        
        # find the first burst that starts, then find the first ending after that,
        # then move to the next starting point, etc.
        crt_start = starts[0]
        while True:
            idx = (ends > crt_start).nonzero()[0]
            open_ended = (len(idx) == 0)
            if open_ended:
                # no closing ISI: the burst lasts until the last recorded spike
                crt_end = n_isi
            else:
                crt_end = ends[idx[0]]
            
            # spikes crt_start..crt_end (inclusive) belong to the burst
            if crt_end - crt_start + 1 >= min_spikes:
                res0.append((crt_start, crt_end))
                res.append((times[crt_start], times[crt_end]))
            
            if open_ended:
                break
            
            idx = (starts > crt_end).nonzero()[0]
            if len(idx) == 0:
                # no more bursts
                break
            
            crt_start = starts[idx[0]]
    
    if not return_idxs:
        return res
    else:
        return (res, res0)

In [ ]:
def calculate_statistics(isi, max_trustworthy=1.0):
    """ Calculate several summary statistics about the spikes that generated the given vector
    of inter-spike intervals (ISIs).
    
    The `isi` vector is assumed to be in seconds. ISIs that go above `max_trustworthy` are
    ignored -- these are assumed to be inter-trial intervals.
    
    Returns a dictionary with the keys 'firing_rate', 'isi_mean', 'isi_std', 'isi_skew',
    'fano_factor', 'cv', 'burst_rate', 'burst_mean_length' and 'firing_rate_bursts'.
    When no ISIs are left after filtering, the ISI-based statistics (including
    'isi_skew') are reported as 0.0. The same holds for the burst statistics when
    there are no bursts.
    """
    # accept lists and 0-d values (e.g. a squeezed single-ISI recording) as well as arrays
    isi = np.atleast_1d(np.asarray(isi))
    isi = isi[isi <= max_trustworthy]
    
    t_max = float(np.sum(isi))
    n_spikes = 1 + len(isi)
    
    if t_max > 0:
        firing_rate = n_spikes/t_max
    else:
        firing_rate = 0.0
    
    if len(isi) > 0:
        isi_mean = np.mean(isi)
        isi_std = np.std(isi)
        isi_skew = stats.skew(isi, bias=False)
        if isi_mean > 0:
            # variance over mean of the ISIs (units of seconds)
            fano_factor = np.var(isi) / isi_mean
            cv = isi_std / isi_mean
        else:
            fano_factor = 0.0
            cv = 0.0
    else:
        isi_mean = 0.0
        isi_std = 0.0
        # previously left unset here, which made the return below raise UnboundLocalError
        isi_skew = 0.0
        fano_factor = 0.0
        cv = 0.0
    
    bursts, bursts_idxs = find_bursts(isi, return_idxs=True)
    if t_max > 0:
        burst_rate = len(bursts)/t_max
    else:
        burst_rate = 0.0
        
    burst_lengths = np.asarray([_[1] - _[0] for _ in bursts])
    
    if len(burst_lengths) > 0:
        burst_mean_length = np.mean(burst_lengths)
    else:
        burst_mean_length = 0.0
    
    # number of spikes for each burst (indices are first and last spike, inclusive)
    burst_nspikes = np.asarray([_[1] - _[0] + 1 for _ in bursts_idxs])
    burst_total_time = np.sum(burst_lengths)
    if burst_total_time > 0:
        firing_rate_bursts = np.sum(burst_nspikes)/burst_total_time
    else:
        firing_rate_bursts = 0.0

    return {
        'firing_rate': firing_rate,
        'isi_mean': isi_mean,
        'isi_std': isi_std,
        'isi_skew': isi_skew,
        'fano_factor': fano_factor,
        'cv': cv,
        'burst_rate': burst_rate,
        'burst_mean_length': burst_mean_length,
        'firing_rate_bursts': firing_rate_bursts
    }

In [ ]:
class SimulationStatistician(object):

    """ Run a spiking simulation with fixed settings and summarize the student
    spike trains it produces with `calculate_statistics`.
    """

    def __init__(self, tmax, dt, **kwargs):
        """ Store the simulation settings.

        Parameters
        ----------
          tmax, dt
              Duration and time step of the simulation.
          n_reps (keyword, default 1)
              Number of repetitions run on each call.

        Every other keyword argument is forwarded to `SpikingLearningSimulation`,
        overriding the defaults chosen here.
        """
        self.tmax = tmax
        self.dt = dt
        # the simulator insists on a target; its contents do not matter here
        self.target = np.zeros((1, int_r(self.tmax/self.dt)))
        self.n_reps = kwargs.pop('n_reps', 1)

        def tracker_generator(simulator, i, n):
            """ Generate some trackers. """
            return {'student_spike': simulation.EventMonitor(simulator.student)}

        self.default_args = dict(
            n_student_per_output=10,
            plasticity_learning_rate=0,
            tutor_rule_gain=0,
            tracker_generator=tracker_generator,
        )
        self.default_args.update(kwargs)

    def __call__(self, **kwargs):
        """ Run the simulation and return its spiking statistics.

        With `combine_reps=False` (the default) the result is a list holding one
        statistics dictionary per repetition. With `combine_reps=True`, the spikes
        of all repetitions are pooled and a single dictionary is returned.

        Every other keyword argument overrides the stored defaults and is passed
        to `SpikingLearningSimulation`. The simulator and its raw results are kept
        in `last_simulator` and `last_res`.
        """
        combine_reps = kwargs.pop('combine_reps', False)

        self.args = dict(self.default_args)
        self.args.update(kwargs)

        self.last_simulator = SpikingLearningSimulation(self.target, self.tmax, self.dt,
                                                        **self.args)
        self.last_res = self.last_simulator.run(self.n_reps)

        if combine_reps:
            # pool every repetition's spikes; /1000.0 converts ISIs from ms to s
            pooled = [rep['student_spike'] for rep in self.last_res]
            return calculate_statistics(collect_isi(pooled, tmax=self.tmax)/1000.0)

        # one summary per repetition; /1000.0 converts ISIs from ms to s
        return [calculate_statistics(collect_isi([rep['student_spike']], tmax=self.tmax)/1000.0)
                for rep in self.last_res]

In [ ]:
def average_stats(stats_list):
    """ Take the average over the values of a list of dictionaries.
    
    The keys are taken from the first dictionary, and every dictionary must contain
    them. Values that cannot be averaged (summing or dividing them raises TypeError
    or ValueError, e.g. strings) are not included in the final result. An empty
    list gives an empty dictionary.
    """
    if len(stats_list) == 0:
        return {}
    
    res = {}
    n = float(len(stats_list))
    for crt_key in stats_list[0].keys():
        crt_values = [_[crt_key] for _ in stats_list]
        try:
            res[crt_key] = sum(crt_values)/n
        except (TypeError, ValueError):
            # not a numeric quantity -- leave it out instead of hiding real errors
            pass
        
    return res

In [ ]:
def draw_summary_stats_violins(epochs, all_stats, all_ages, simstats, title_map, order,
                               color=[0.831, 0.333, 0.000],
                               plot_min_age=43, plot_max_age=160, simpos='left'):
    """ Draw a violin plot comparing different bird ages and simulation results statistics.

    One subplot is drawn in the current figure for each statistic named in `order`, on a
    roughly square grid. Each subplot has one violin per age epoch plus one violin for
    the simulation results.

    Parameters
    ----------
      epochs: sequence of (start, end) pairs
          Age ranges. A recording belongs to an epoch if start <= age < end.
      all_stats: sequence of dict
          Summary statistics for each recording (e.g. from `calculate_statistics`).
      all_ages: sequence of numbers
          Age for each entry of `all_stats`.
      simstats: sequence of dict
          Summary statistics for each simulation run.
      title_map: dict
          Maps each statistic name to its y-axis label.
      order: list of str
          Statistics to plot, in subplot order.
      color:
          Color of the experimental violins.
      plot_min_age, plot_max_age:
          Not used; kept for backward compatibility.
      simpos: 'left' or 'right'
          Side on which the simulation violin is drawn.

    Raises
    ------
      Exception if `simpos` is neither 'left' nor 'right'.
    """
    # imported locally: `math` is not among the notebook's explicit imports
    import math

    # validate before drawing anything
    if simpos not in ('left', 'right'):
        raise Exception('Unknown simpos.')

    n_plots = len(order)
    if n_plots == 0:
        return
    nx = int(math.ceil(math.sqrt(n_plots)))
    ny = int(math.ceil(float(n_plots)/nx))

    # group the per-recording statistics by age epoch
    epoch_stats = [[crt_stats for (crt_stats, crt_age) in zip(all_stats, all_ages)
                    if crt_epoch[0] <= crt_age < crt_epoch[1]]
                   for crt_epoch in epochs]

    simcolor = (0.2, 0.5, 1.0)
    simtext = 'sim'

    for (i, crtplot) in enumerate(order):
        plt.subplot(ny, nx, 1+i)

        data = [np.asarray([x[crtplot] for x in crt_stats]) for crt_stats in epoch_stats]
        colors = [color]*len(epoch_stats)
        names = ['{}-{}'.format(*crt_epoch) for crt_epoch in epochs]

        # now add the simulated values on the requested side
        simvalues = np.asarray([x[crtplot] for x in simstats])
        if simpos == 'right':
            data.append(simvalues)
            colors.append(simcolor)
            names.append(simtext)
        else:
            data.insert(0, simvalues)
            colors.insert(0, simcolor)
            names.insert(0, simtext)

        # and draw!
        sns.violinplot(data=data, palette=colors, scale='width')
        plt.xticks(range(len(names)), names, rotation=45)
        # anchor the y axis at zero, keeping the automatic upper limit
        plt.ylim(0, plt.ylim()[1])

        plt.grid(True)
        plt.ylabel(title_map[crtplot])

        ax = plt.gca()
        ax.yaxis.label.set_fontsize(8)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontsize(8)

In [ ]:
def draw_summary_stats_violins_double(epochs, all_stats, all_ages, simstats1, simstats2, 
                               title_map, order,
                               color=[0.831, 0.333, 0.000],
                               plot_min_age=43, plot_max_age=160, simtext=('sim1', 'sim2')):
    """ Draw a violin plot comparing different bird ages with the statistics of two
    simulations.

    One subplot is drawn in the current figure for each statistic named in `order`, on a
    roughly square grid. The violin for `simstats1` goes to the left of the experimental
    epochs and the violin for `simstats2` to the right.

    Parameters
    ----------
      epochs: sequence of (start, end) pairs
          Age ranges. A recording belongs to an epoch if start <= age < end.
      all_stats: sequence of dict
          Summary statistics for each recording (e.g. from `calculate_statistics`).
      all_ages: sequence of numbers
          Age for each entry of `all_stats`.
      simstats1, simstats2: sequence of dict
          Summary statistics for the runs of the first and second simulation.
      title_map: dict
          Maps each statistic name to its y-axis label.
      order: list of str
          Statistics to plot, in subplot order.
      color:
          Color of the experimental violins.
      plot_min_age, plot_max_age:
          Not used; kept for backward compatibility.
      simtext: pair of str
          Tick labels for the first and second simulation.
    """
    # imported locally: `math` is not among the notebook's explicit imports
    import math

    n_plots = len(order)
    if n_plots == 0:
        return
    nx = int(math.ceil(math.sqrt(n_plots)))
    ny = int(math.ceil(float(n_plots)/nx))

    # group the per-recording statistics by age epoch
    epoch_stats = [[crt_stats for (crt_stats, crt_age) in zip(all_stats, all_ages)
                    if crt_epoch[0] <= crt_age < crt_epoch[1]]
                   for crt_epoch in epochs]

    simcolor = (0.2, 0.5, 1.0)

    for (i, crtplot) in enumerate(order):
        plt.subplot(ny, nx, 1+i)

        data = [np.asarray([x[crtplot] for x in crt_stats]) for crt_stats in epoch_stats]
        colors = [color]*len(epoch_stats)
        names = ['{}-{}'.format(*crt_epoch) for crt_epoch in epochs]

        # second simulation on the right, first simulation on the left
        simvalues1 = np.asarray([x[crtplot] for x in simstats1])
        simvalues2 = np.asarray([x[crtplot] for x in simstats2])

        data.append(simvalues2)
        colors.append(simcolor)
        names.append(simtext[1])

        data.insert(0, simvalues1)
        colors.insert(0, simcolor)
        names.insert(0, simtext[0])

        # and draw!
        sns.violinplot(data=data, palette=colors, scale='width')
        plt.xticks(range(len(names)), names, rotation=45)
        # anchor the y axis at zero, keeping the automatic upper limit
        plt.ylim(0, plt.ylim()[1])

        plt.grid(True)
        plt.ylabel(title_map[crtplot])

        ax = plt.gca()
        ax.yaxis.label.set_fontsize(8)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontsize(8)

Load experimental spiking data


In [ ]:
# load the experimental recordings (presumably one record per recorded cell, each
# exposing at least `spikeI` and `dph`)
spike_stats_raw = scipy.io.loadmat("data/allCellsNew.mat",
                                   squeeze_me=True, chars_as_strings=True,
                                   struct_as_record=False)
all_data = spike_stats_raw['allCellsNew']
# inter-spike intervals for each record; calculate_statistics below treats them as
# seconds (the vectors presumably differ in length, making this an object array)
isi_data = np.asarray([_.spikeI for _ in all_data])
# age (dph) of the bird for each record
ages = np.asarray([x.dph for x in all_data])

Select recordings from birds younger than 200 dph, because all recordings from 200 dph birds were made during directed singing.


In [ ]:
# drop the 200 dph recordings (all directed singing)
selected_mask = (ages < 200)
selected_isi_data = isi_data[selected_mask]
selected_ages = ages[selected_mask]

Calculate statistics of experimental data


In [ ]:
# summary statistics for every recording; np.asarray turns the list of dicts into an
# object array so that it can be indexed with the boolean `selected_mask`
all_stats = np.asarray([calculate_statistics(crt_isi) for crt_isi in isi_data])
selected_stats = all_stats[selected_mask]

In [ ]:
# descriptive axis labels (with units) for each statistic returned by calculate_statistics;
# the key order also defines the predictors of the linear age model
# (fano_factor is var(ISI)/mean(ISI), hence units of seconds)
stats_title_map = {
    'burst_mean_length': 'Mean duration of bursts (s)',
    'burst_rate': 'Rate at which bursts occur (Hz)',
    'cv': 'Coefficient of variation',
    'fano_factor': 'Fano factor (s)',
    'firing_rate': 'Average firing rate (Hz)',
    'firing_rate_bursts': 'Average firing rate during bursts (Hz)',
    'isi_mean': 'Mean inter-spike interval (s)',
    'isi_std': 'Standard deviation of inter-spike intervals (s)',
    'isi_skew': 'Skewness of inter-spike interval distribution'
}

Build a linear predictor of age from the summary statistics


In [ ]:
# least-squares fit of age against a constant plus every summary statistic
# (list(...) keeps this working where dict.keys() is not a list)
predictor_names = ['constant'] + list(stats_title_map.keys())
# design matrix: one row per selected recording, one column per predictor
# (np.column_stack needs a sequence, not a generator, in recent numpy versions)
selected_summary_stats = np.column_stack(
    [[(x[name] if name != 'constant' else 1) for x in selected_stats]
     for name in predictor_names])
selected_ages_float = np.asarray(selected_ages, dtype=float)
lincoeffs, linresiduals, _, _ = np.linalg.lstsq(selected_summary_stats, selected_ages_float)
lincoeffs_mapped = {name: lincoeffs[i] for (i, name) in enumerate(predictor_names)}
guessed_selected_ages = np.dot(selected_summary_stats, lincoeffs)
# residual sum of squares computed directly: lstsq returns an empty `linresiduals`
# when the design matrix is rank-deficient
# NOTE(review): the denominator ignores the number of fitted coefficients
residual_ss = np.sum((selected_ages_float - guessed_selected_ages)**2)
print('Size of residuals is {:.2f} dph.'.format(np.sqrt(residual_ss/(len(selected_ages)-1))))

In [ ]:
# real vs. predicted ages; the highlighted points are recordings 68 (red) and 122 (green),
# the ones used below as juvenile and adult optimization targets
plt.scatter(selected_ages, guessed_selected_ages, alpha=0.6)
plt.scatter([selected_ages[68]], [guessed_selected_ages[68]], c='r')   # juvenile
plt.scatter([selected_ages[122]], [guessed_selected_ages[122]], c='g') # adult
plt.xlabel("Real age (dph)")
plt.ylabel("Age inferred from linear model (dph)")
print("Correlation coefficient: {:.2f}.".format(np.corrcoef(selected_ages, guessed_selected_ages)[0, 1]))

Find summary statistics that are strongly correlated with one another.


In [ ]:
# pairwise correlations between the predictors (columns of the design matrix); the
# 'constant' column has zero variance, so its row and column come out as NaN
sns.heatmap(np.corrcoef(selected_summary_stats.T), square=True, vmin=-1.0, vmax=1.0)
predictor_names_disp = [s.replace('_', ' ') for s in predictor_names]
# +0.5 centers each label on its heatmap cell
plt.xticks(0.5 + np.arange(len(predictor_names_disp)), predictor_names_disp, rotation='vertical')
# XXX seaborn seems to flip the labels but not the coordinates!
plt.yticks(len(predictor_names_disp) - 0.5 - np.arange(len(predictor_names_disp)),
           predictor_names_disp, rotation='horizontal');

Search for simulation parameters that best model measured spiking

Useful functions for optimization


In [ ]:
def pack_params(d, selection):
    """ Flatten selected entries of a parameter dictionary into one vector.

    Parameters
    ----------
      d: dict
          Parameter values, keyed by name.
      selection: list of str
          Names to pack, in order. A name of the form 'name::n' marks a value with
          `n` components; all of them are appended, in order.

    Returns
    -------
      A numpy array with the packed values.

    Raises
    ------
      Exception if a multi-component value is not a sized object of length `n`.
    """
    flat = []
    for name in selection:
        value = d[name]
        _, marker, count = name.partition('::')
        if not marker:
            flat.append(value)
            continue

        if not hasattr(value, '__len__') or len(value) != int(count):
            raise Exception('Mismatch between name and length of value.')
        flat.extend(value)

    return np.asarray(flat)

In [ ]:
def unpack_params(x, selection):
    """ Rebuild a parameter dictionary from a vector produced by `pack_params`.

    Each name in `selection` consumes entries of `x` in order. A plain name takes a
    single entry. A name of the form 'name::n' takes `n` entries, stored as a tuple.
    The keys of the result are the names exactly as given, including any '::n' suffix.
    """
    out = {}
    pos = 0
    for name in selection:
        _, marker, count = name.partition('::')
        if marker:
            width = int(count)
            out[name] = tuple(x[pos:pos + width])
        else:
            width = 1
            out[name] = x[pos]
        pos += width

    return out

In [ ]:
def sanitize_keys(d):
    """ Return a copy of `d` whose keys have everything from the first '::' onward
    removed (e.g. 'weights::2' becomes 'weights').

    If two keys collapse to the same name, the one visited last wins.
    """
    return {key.partition('::')[0]: value for key, value in d.items()}

In [ ]:
def get_bounds_trafo(bounds_dict, selection):
    """ Generate 'genotype' to 'phenotype' and reverse transformations
    for use with CMA-ES.

    `bounds_dict` maps each name in `selection` to a (lower, upper) pair (or, for
    'name::n' entries, to n such pairs). The forward map

        y = lower + (upper - lower)*(1 - cos(pi*x/10))/2

    keeps every phenotype inside its bounds for any genotype x (it is periodic with
    period 20). The reverse map returns genotypes in [0, 10]. Phenotypes on or outside
    the bounds map to the nearest end of that interval instead of NaN.

    Returns
    -------
      (bounds_trafo, bounds_rev_trafo)
    """
    bounds = pack_params(bounds_dict, selection)
    lower = bounds[:, 0]
    span = bounds[:, 1] - bounds[:, 0]

    def bounds_trafo(x):
        return lower + 0.5*span*(1.0 - np.cos(np.pi*x/10.0))

    def bounds_rev_trafo(y):
        # clip so that values sitting exactly on a bound (up to rounding) or beyond it
        # do not push arccos outside its [-1, 1] domain and produce NaN
        c = np.clip(1.0 - 2.0*(y - lower)/span, -1.0, 1.0)
        return 10.0/np.pi*np.arccos(c)
    
    return (bounds_trafo, bounds_rev_trafo)

Define the parameters to optimize and the objective function


In [ ]:
# choose statistics to optimize
# keys of the calculate_statistics output that enter the objective function
optim_stats = ['firing_rate', 'cv', 'isi_skew', 'burst_rate', 'burst_mean_length', 'firing_rate_bursts']

In [ ]:
def stats_diff(stats1, stats2, optim_stats):
    """ Root-mean-square relative deviation of `stats1` from `stats2`.

    For each key in `optim_stats` the ratio stats1[key]/stats2[key] is formed. The
    result is the Euclidean norm of (ratio - 1) divided by the square root of the
    number of keys. A zero value in `stats2` makes the division fail (or give inf,
    depending on the value types).
    """
    deviations = np.asarray([stats1[name] / stats2[name] for name in optim_stats]) - 1.0
    return np.linalg.norm(deviations) / np.sqrt(len(deviations))

In [ ]:
class ObjectiveFunction(object):
    
    """ An objective function that calculates spiking statistics, for use
    with CMA-ES.
    """
    
    def __init__(self, *args, **kwargs):
        """ Initialize the objective function.
        
        Arguments
        ---------
          target_stats: dictionary
              Target values for the statistics.
          optim_stats: list
              Select the statistics that are optimized.
          selection: list
              Parameters that are being optimized.
          defaults: dict
              Default simulation parameters. Values unpacked from the optimization
              vector override them.
          pack_fct:
          unpack_fct:
              Functions that pack and unpack the parameters between vector
              form and dictionary form. Their signatures are
                  pack_fct(dict, selection)
                  unpack_fct(x, selection)
              Only `unpack_fct` is used when the objective is evaluated.
        
        All other arguments are passed directly to `SimulationStatistician`.
        """
        self.target_stats = kwargs.pop('target_stats')
        self.optim_stats = kwargs.pop('optim_stats')
        self.selection = kwargs.pop('selection')
        self.defaults = kwargs.pop('defaults')
        
        self.pack_fct = kwargs.pop('pack_fct')
        self.unpack_fct = kwargs.pop('unpack_fct')
        
        self.statistician = SimulationStatistician(*args, **kwargs)
        
    def __call__(self, x):
        """ Calculate how different the statistics of the simulation given the
        parameters `x` are from the target statistics.
        
        Parameters
        ----------
          x
              Numeric array describing the parameters for the simulation.
        
        Returns
        -------
          A single number representing a measure of how far the current simulation is
          from the target, in terms of the summary statistics identified by `self.optim_stats`.
          
          Given the current and target values of the statistics, `crt` and `tgt`, the
          output is the root-mean-square relative deviation
            `np.linalg.norm(crt/tgt - 1.0)/np.sqrt(len(crt))`
        
        The statistics of this evaluation, averaged over repetitions, are stored in
        `self.avg_stats`.
        """
        args = dict(self.defaults)
        args.update(self.unpack_fct(x, self.selection))
        
        # remove the '::<n>' length markers from the parameter names
        args = sanitize_keys(args)
        
        # one statistics dict per repetition; the name avoids shadowing the
        # module-level `stats` (scipy.stats)
        rep_stats = self.statistician(**args)
        if len(rep_stats) > 1:
            avg_stats = average_stats(rep_stats)
        else:
            avg_stats = rep_stats[0]
        
        self.avg_stats = avg_stats
        
        return stats_diff(self.avg_stats, self.target_stats, self.optim_stats)

In [ ]:
def split_args(args, i):
    """ Select the arguments that apply to simulation `i`.

    Any '::<n>' length marker is dropped from each key first. Keys without '##' are
    common to all simulations and are kept as they are. Keys of the form 'name##<j>'
    are kept, renamed to 'name', only when j == i. All other keys are dropped.
    """
    selected = {}
    for raw_key, value in args.items():
        key = raw_key.partition('::')[0]
        base, tag, index = key.partition('##')
        if not tag:
            # shared by every simulation
            selected[key] = value
        elif int(index) == i:
            # specific to this simulation
            selected[base] = value

    return selected

In [ ]:
def merge_dicts(d1, d2):
    """ Return a new dictionary holding the entries of `d1` and `d2`.

    On shared keys the value from `d2` wins. Neither input is modified.
    """
    merged = {}
    merged.update(d1)
    merged.update(d2)
    return merged

In [ ]:
class MultipleObjectiveFunction(object):

    """ An objective function that optimizes the results of several
    simulations instead of one.
    
    The simulations share some parameters, while others are separate.
    """
    
    def __init__(self, *args, **kwargs):
        """ Initialize the objective function.
        
        Arguments
        ---------
          target_stats: list of dictionaries
              Target values for the statistics, one dictionary per simulation.
          optim_stats: list
              Select the statistics that are optimized.
          selection: list
              Parameters that are being optimized. Parameters that have names
              of the form 'param##0' or 'param##2::3' are specific to a given
              simulation (index 0 or 2, respectively, for these examples).
              Other parameters are common for all simulations. Note that the
              indication for vector parameters ('::3' here) comes after the
              simulation index.
          defaults: dict
              Default parameters that are used if they are not replaced by any
              of the values from `selection`.
          pack_fct:
          unpack_fct:
              Functions that pack and unpack parameters between vector
              form and dictionary form. Their signatures are
                  pack_fct(dict, selection)
                  unpack_fct(x, selection)
              Only `unpack_fct` is used when the objective is evaluated.
        
        All other arguments are passed directly to `SimulationStatistician`. One
        statistician is created per entry of `target_stats`, all with these same
        arguments.
        """
        self.target_stats = kwargs.pop('target_stats')
        self.optim_stats = kwargs.pop('optim_stats')
        self.selection = kwargs.pop('selection')
        self.defaults = kwargs.pop('defaults')
        
        self.pack_fct = kwargs.pop('pack_fct')
        self.unpack_fct = kwargs.pop('unpack_fct')
        
        self.n = len(self.target_stats)
        self.statisticians = [SimulationStatistician(*args, **kwargs) for _ in xrange(self.n)]
        
    def __call__(self, x):
        """ Calculate how different the statistics of the simulations given the
        parameters `x` are from the target statistics.
        
        Parameters
        ----------
          x
              Numeric array describing the parameters for the simulations.
        
        Returns
        -------
          A single number representing a measure of how far the current simulations are
          from their targets, in terms of the summary statistics identified by
          `self.optim_stats`.
          
          Given the current and target values of the statistics for simulation i,
          `crt[i]` and `tgt[i]`, the output is the mean over simulations of the
          root-mean-square relative deviation,
            `np.mean([np.linalg.norm(crt[i]/tgt[i] - 1.0)/np.sqrt(len(crt[i]))
                      for i in range(n)])`
        
        The statistics of each simulation, averaged over repetitions, are stored in
        `self.avg_stats` (a list with one dictionary per simulation).
        """
        args = dict(self.defaults)
        args.update(self.unpack_fct(x, self.selection))
        # strip the '::<n>' length markers; the '##<i>' simulation tags are kept
        args = sanitize_keys(args)
        
        avg_stats = []
        
        for i, crt_statistician in enumerate(self.statisticians):
            # common parameters plus those tagged '##<i>', with the tag removed
            crt_args = split_args(args, i)
            
            crt_stats = crt_statistician(**crt_args)
            if len(crt_stats) > 1:
                crt_stats = average_stats(crt_stats)
            else:
                crt_stats = crt_stats[0]
            
            avg_stats.append(crt_stats)
        
        self.avg_stats = avg_stats
        
        diffs = [stats_diff(crt, tgt, self.optim_stats) for (crt, tgt) in
                 zip(self.avg_stats, self.target_stats)]
        
        return np.mean(diffs)

Use CMA-ES to jointly optimize juvenile and adult simulations


In [ ]:
# choose recordings to use for juvenile and adult bird, respectively
# indices into `selected_stats` / `selected_ages` (not into the full data set)
idx_juvenile = 68
idx_adult = 122

In [ ]:
# choose ranges for common parameters
# Search ranges for the optimized parameters, as (lower, upper) pairs. Names ending in
# '##<i>' apply only to simulation i (0 = juvenile, 1 = adult, following the order of
# `target_stats` below). A '::<n>' suffix marks an n-component value whose range is
# given as n (lower, upper) pairs.
optim_range = {    
    'conductor_rate_during_burst': (400, 800),
    'student_vR': (-75.0, -65.0),
    'student_v_th': (-55.0, -45.0),
    'student_R': (100.0, 400.0),
    'student_tau_m': (10.0, 25.0),
    'student_tau_ampa': (3.0, 15.0),
    'student_tau_nmda': (80.0, 120.0),
    'student_g_inh': (0.1, 2.0),
#    'student_i_external': (-0.5, -0.05),
    'student_tau_ref': (1.0, 2.0),
    'cs_weights_params##0::2': ((-3.62, -3.52), (0.50, 0.58)),
    'cs_weights_params##1::2': ((-2.87, -2.71), (0.71, 0.81)),
    'cs_weights_fraction##0': (0.06, 0.5),
    'cs_weights_fraction##1': (0.03, 0.5),
    'ts_weights': (0.10, 0.20)
}
# parameters that are actually optimized; this order fixes the layout of the CMA-ES
# search vector (via pack_params/unpack_params)
optim_params = ['conductor_rate_during_burst', 'student_v_th', 'student_R', 'student_vR',
                'student_tau_m', 'student_tau_ampa',
                'student_tau_nmda', 'student_tau_ref', #'student_i_external',
                'student_g_inh',
                'cs_weights_params##0::2', 'cs_weights_params##1::2',
                'cs_weights_fraction##0', 'cs_weights_fraction##1', 'ts_weights']

In [ ]:
# choose a starting point (guess values)
# choose a starting point (guess values)
# These also serve as the objective's `defaults`. Note that 'student_R' starts exactly
# at the upper end of its search range.
guess_params = {
    'conductor_rate_during_burst': 718.0,
    'cs_weights_fraction##0': 0.4,
    'cs_weights_fraction##1': 0.17,
    'cs_weights_params##0::2': (-3.54, 0.54),
    'cs_weights_params##1::2': (-2.72, 0.76),
    'student_R': 400.0,
    'student_g_inh': 1.48,
    'student_tau_ampa': 4.8,
    'student_tau_m': 24.5,
    'student_tau_nmda': 94.4,
    'student_tau_ref': 1.9,
    'student_vR': -73.7,
    'student_v_th': -53.1,
    'ts_weights': 0.12
}

In [ ]:
# length and timestep of simulation
# length and timestep of simulation (presumably ms, since the statistician divides the
# resulting ISIs by 1000 to get seconds)
tmax = 600.0
dt = 0.2

# joint objective: simulation 0 is matched to the juvenile recording and simulation 1
# to the adult one; the remaining keyword arguments configure every statistician
objective = MultipleObjectiveFunction(tmax, dt,
                              target_stats=[selected_stats[idx_juvenile], selected_stats[idx_adult]],
                              optim_stats=optim_stats, selection=optim_params,
                              defaults=guess_params,
                              pack_fct=pack_params, unpack_fct=unpack_params,
                              n_conductor=300, n_student_per_output=100,
                              relaxation=400.0, relaxation_conductor=25.0,
                              plasticity_constrain_positive=False)

In [ ]:
# map between the unbounded CMA-ES search space and the bounded parameter ranges;
# in search space, an interval of length 10 covers each parameter's full range, which
# sets the scale for sigma0
bounds_trafos = get_bounds_trafo(optim_range, optim_params)
# run the optimization -- this can take a while
cma_res = cma.fmin(
        objective, pack_params(guess_params, optim_params), sigma0=3.0,
        options={'tolfun': 1e-3, 'tolfunhist': 1e-4, 'tolx': 1e-4,
             'transformation': bounds_trafos,
             'maxfevals': 1000, 'verb_disp': 1}
    )

In [ ]:
# optimized parameters keyed like `optim_params` (cma_res[0] is presumably the best
# solution found -- confirm against the cma version in use)
best_params = unpack_params(cma_res[0], optim_params)

In [ ]:
# save parameters to file
best_params_full = merge_dicts(objective.statisticians[0].default_args,
           sanitize_keys(merge_dicts(objective.defaults, best_params)))

# get rid of tracker/snapshot generators
best_params_full.pop('tracker_generator', None)
best_params_full.pop('snapshot_generator', None)

# include some extra data
best_params_full['tmax'] = tmax
best_params_full['dt'] = dt

# do the actual saving
imax = 1000
for i in xrange(imax):
    # save to the first available slot, name ending in _0, _1, _2, ...
    file_name = 'best_params_joint_{}.pkl'.format(i)
    if os.path.exists(file_name):
        continue
    
    with open(file_name, 'wb') as out:
        pickle.dump(best_params_full, out, 2)
    
    break
    
# copy the resulting file to default_params.pkl to be used by the figure-making code
# below and by the simulation code in spiking_simulations.ipynb

In [ ]:
# load the data written by the CMA-ES logger (presumably cma_res[-1])
cma_res[-1].load()
# NOTE(review): np.min over axis=1 takes the minimum across *all* logged columns of `f`
# for each logged iteration -- confirm for the cma version in use that this picks out
# the best objective value
plt.plot(np.min(cma_res[-1].f, axis=1))
# the objective values are on the y axis; the x axis is the logged iteration
plt.xlabel('logged iteration')
plt.ylabel('error');

Check the spiking patterns produced by the jointly optimized parameters


In [ ]:
# full argument sets for the two simulations: the statistician defaults, overridden by
# the guessed and then the optimized parameters, keeping common entries plus those
# tagged '##0' (juvenile) or '##1' (adult)
args_juvenile = merge_dicts(objective.statisticians[0].default_args,
                           split_args(merge_dicts(guess_params, best_params), 0))
args_adult = merge_dicts(objective.statisticians[1].default_args,
                        split_args(merge_dicts(guess_params, best_params), 1))

In [ ]:
# rerun each optimized model for 10 repetitions and compute statistics over the
# spikes pooled across repetitions (/1000.0 converts the ISIs from ms to s)
sim_juvenile = SpikingLearningSimulation(objective.statisticians[0].target, tmax, dt, **args_juvenile)
res_juvenile = sim_juvenile.run(10)
stats_juvenile = calculate_statistics(collect_isi([_['student_spike'] for _ in res_juvenile],
        tmax=tmax)/1000.0)

sim_adult = SpikingLearningSimulation(objective.statisticians[1].target, tmax, dt, **args_adult)
res_adult = sim_adult.run(10)
stats_adult = calculate_statistics(collect_isi([_['student_spike'] for _ in res_adult],
        tmax=tmax)/1000.0)

In [ ]:
# student spikes from the 10 juvenile repetitions
show_repetition_pattern([_['student_spike'] for _ in res_juvenile], idx=range(10), ms=2)
plt.xlim(0, tmax);
plt.title('Juvenile bird');

In [ ]:
# student spikes from the 10 adult repetitions
show_repetition_pattern([_['student_spike'] for _ in res_adult], idx=range(10), ms=2)
plt.xlim(0, tmax);
plt.title('Adult bird');

Make figures

Violin plots


In [ ]:
# load the default parameters
# load the default parameters
# NOTE: unpickling can execute arbitrary code -- only load trusted files
with open('default_params.pkl', 'rb') as inp:
    default_params = pickle.load(inp)

In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the trackers attached to a simulation run.

    Only the student population's spikes are monitored, under the key
    'student_spike'. The `i` and `n` arguments are ignored.
    """
    return {'student_spike': simulation.EventMonitor(simulator.student)}

In [ ]:
# simulate the default (juvenile) parameters with 50 repetitions
actual_params = dict(default_params)
actual_params['tracker_generator'] = tracker_generator
# SimulationStatistician builds its own dummy target, so drop any stored one; the
# default avoids a KeyError for parameter files saved without a 'target' entry
actual_params.pop('target', None)
t0 = time.time()
# NOTE(review): the parameter file must provide 'tmax' and 'dt' for this call
sim = SimulationStatistician(n_reps=50, **actual_params)
res = sim()
t1 = time.time()
print("Simulation took {:.2f} seconds.".format(t1 - t0))

In [ ]:
# 25-day age bins: 40-65, 65-90, 90-115, 115-140, 140-165 dph
epochs = [(i, i+25) for i in xrange(40, 160, 25)]
# short axis labels for the figures
short_title_map = {
    'burst_mean_length': 'Average burst length (s)',
    'burst_rate': 'Burst frequency (Hz)',
    'cv': 'CV of ISI',
    'fano_factor': 'Fano factor (s)', # XXX what's the right normalization to get this to be dimensionless?
    'firing_rate': 'Average firing rate (Hz)',
    'firing_rate_bursts': 'Firing rate during bursts (Hz)',
    'isi_mean': 'Average ISI (s)',
    'isi_std': 'Standard deviation of ISI (s)',
    'isi_skew': 'Skewness of ISI'
}

In [ ]:
# compare the simulated (juvenile-parameter) statistics with the experimental age bins
plt.figure(figsize=[6, 3.8])
draw_summary_stats_violins(epochs, selected_stats, selected_ages,
                           simstats=res, title_map=short_title_map,
                           order=['firing_rate', 'cv', 'isi_skew',
                                  'burst_rate', 'burst_mean_length',
                                  'firing_rate_bursts'])
plt.tight_layout()

safe_save_fig('figs/spiking_matching_violins', png=False)

In [ ]:
# copy, so that the juvenile parameters in `actual_params` are not modified
# (plain assignment would alias the same dictionary)
actual_params_adult = dict(actual_params)
# adult-specific cs_weights values
# TODO: record where these numbers come from (presumably an optimization run)
actual_params_adult['cs_weights_fraction'] = 0.15157340812668241
actual_params_adult['cs_weights_params'] = (-2.7101150158852856, 0.80989719997502518)
t0 = time.time()
sim_adult = SimulationStatistician(n_reps=50, **actual_params_adult)
res_adult = sim_adult()
t1 = time.time()
print("Simulation took {:.2f} seconds.".format(t1 - t0))

In [ ]:
# juvenile simulation on the left, adult simulation on the right, experimental age
# bins in between
plt.figure(figsize=[6, 3.8])
draw_summary_stats_violins_double(epochs, selected_stats, selected_ages,
                           simstats1=res, simstats2=res_adult,
                           title_map=short_title_map,
                           order=['firing_rate', 'cv', 'isi_skew',
                                  'burst_rate', 'burst_mean_length',
                                  'firing_rate_bursts'],
                           simtext=['sim juvenile', 'sim adult'])
plt.tight_layout()

safe_save_fig('figs/spiking_matching_violins_double', png=False)

In [ ]: