In [ ]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
plt.rc('text', usetex=True)
plt.rc('font', family='serif', serif='cm')
plt.rcParams['figure.titlesize'] = 10
plt.rcParams['axes.labelsize'] = 8
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['axes.labelpad'] = 3.0
from IPython.display import display, clear_output
from ipywidgets import FloatProgress
# comment out the `set_matplotlib_formats` call below if not working on a retina-display computer
import IPython
IPython.display.set_matplotlib_formats('retina')
In [ ]:
%load_ext autoreload
%autoreload 2
In [ ]:
import numpy as np
import scipy.io
import scipy.optimize as optimize
import scipy.stats as stats
import copy
import time
import math
import os
import cPickle as pickle
import cma
In [ ]:
import simulation
from basic_defs import *
from helpers import *
In [ ]:
def lognormal_to_meanstdev(mu, sigma):
""" Find lognormal mean and standard deviation given mu and sigma
parameters.
Parameters
----------
mu, sigma
Mean and standard deviation of associated normal distribution.
Returns
-------
(mean, stdev)
Mean and standard deviation of lognormal distribution.
"""
return (np.exp(mu + sigma**2/2),
np.sqrt((np.exp(sigma**2) - 1)*np.exp(2*mu + sigma**2)))
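As a quick sanity check of `lognormal_to_meanstdev` (an illustrative example with arbitrary parameter values, not part of the analysis), the analytic moments can be compared to a Monte Carlo estimate.
In [ ]:
# illustrative sanity check: analytic lognormal moments vs. sampled moments
mu_test, sigma_test = 0.5, 0.3
mean_analytic, stdev_analytic = lognormal_to_meanstdev(mu_test, sigma_test)
samples = np.random.lognormal(mean=mu_test, sigma=sigma_test, size=200000)
print('analytic: mean={:.3f}, stdev={:.3f}'.format(mean_analytic, stdev_analytic))
print('sampled:  mean={:.3f}, stdev={:.3f}'.format(samples.mean(), samples.std()))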
In [ ]:
def find_bursts(isi, percentile_threshold_in=65, percentile_threshold_out=80, min_spikes=3,
isi_in=1.0/80.0, isi_out=1.0/40.0, return_idxs=False):
""" Find the bursts in a vector ISIs.
Bursts are returned as a vector of tuples, (start, end), giving times of burst start
and end using the first spike in the recording as a reference point.
Change `percentile_threshold_in` and `percentile_threshold_out` to set the threshold for
entering a burst and exiting one, respectively. Low percentile means low ISI threshold.
The values are in percent.
If `isi_in` and `isi_out` are not `None`, they are used instead.
A burst will only be recorded if it has at least `min_spikes` spikes.
If `return_idxs` is `True`, the function also returns the pairs of indices where bursts
start and end.
"""
isi = np.asarray(isi)
if isi_in is None:
isi_threshold_in = np.percentile(isi, percentile_threshold_in)
else:
isi_threshold_in = isi_in
if isi_out is None:
isi_threshold_out = np.percentile(isi, percentile_threshold_out)
else:
isi_threshold_out = isi_out
# places where bursts may start
starts = (isi < isi_threshold_in).nonzero()[0]
# ...and end
ends = (isi >= isi_threshold_out).nonzero()[0]
# find the first burst that starts, then find the first ending after that,
# then move to the next starting point, etc.
if len(starts) > 0:
crt_start = starts[0]
res0 = []
res = []
# res0 contains start and end *indices*
# transform them to times
times = np.hstack(([0], np.cumsum(isi)))
while True:
idx = (ends > crt_start).nonzero()[0]
if len(idx) == 0:
# we're done, burst ends at end of simulation
res0.append((crt_start, len(isi)))
break
crt_end = ends[idx[0]]
if crt_end - crt_start + 1 >= min_spikes:
res0.append((crt_start, crt_end))
res.append((times[crt_start], times[crt_end]))
idx = (starts > crt_end).nonzero()[0]
if len(idx) == 0:
# no more bursts
break
crt_start = starts[idx[0]]
else:
# no bursts
res = []
res0 = []
if not return_idxs:
return res
else:
return (res, res0)
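The burst detector can be exercised on a short synthetic ISI vector (illustrative values in seconds, using the default thresholds from above).
In [ ]:
# two groups of fast spikes (ISIs below isi_in = 1/80 s) separated by slow spiking
demo_isi = np.asarray([0.005, 0.006, 0.005, 0.3, 0.4, 0.005, 0.004, 0.006, 0.5])
demo_bursts, demo_burst_idxs = find_bursts(demo_isi, return_idxs=True)
print('burst (start, end) times in s: {}'.format(demo_bursts))
print('burst (start, end) ISI indices: {}'.format(demo_burst_idxs))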
In [ ]:
def calculate_statistics(isi, max_trustworthy=1.0):
""" Calculate several summary statistics about the spikes that generated the given vector
of inter-spike intervals (ISIs).
    The `isi` vector is assumed to be in seconds. ISIs longer than `max_trustworthy` are
    ignored -- these are assumed to be inter-trial intervals.
"""
isi = isi[isi <= max_trustworthy]
t_max = float(np.sum(isi))
n_spikes = 1 + len(isi)
if t_max > 0:
firing_rate = n_spikes/t_max
else:
firing_rate = 0.0
if len(isi) > 0:
isi_mean = np.mean(isi)
isi_std = np.std(isi)
isi_skew = stats.skew(isi, bias=False)
if isi_mean > 0:
fano_factor = np.var(isi) / isi_mean
cv = isi_std / isi_mean
else:
fano_factor = 0.0
cv = 0.0
else:
        isi_mean = 0.0
        isi_std = 0.0
        isi_skew = 0.0
        fano_factor = 0.0
        cv = 0.0
bursts, bursts_idxs = find_bursts(isi, return_idxs=True)
if t_max > 0:
burst_rate = len(bursts)/t_max
else:
burst_rate = 0.0
burst_lengths = np.asarray([_[1] - _[0] for _ in bursts])
if len(burst_lengths) > 0:
burst_mean_length = np.mean(burst_lengths)
else:
burst_mean_length = 0.0
# number of spikes for each burst
burst_nspikes = np.asarray([_[1] - _[0] + 1 for _ in bursts_idxs])
burst_total_time = np.sum(burst_lengths)
if burst_total_time > 0:
firing_rate_bursts = np.sum(burst_nspikes)/burst_total_time
else:
firing_rate_bursts = 0.0
return {
'firing_rate': firing_rate,
'isi_mean': isi_mean,
'isi_std': isi_std,
'isi_skew': isi_skew,
'fano_factor': fano_factor,
'cv': cv,
'burst_rate': burst_rate,
'burst_mean_length': burst_mean_length,
'firing_rate_bursts': firing_rate_bursts
}
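For reference, the statistics can be computed for any ISI vector; here is an illustrative call on random, exponentially-distributed ISIs (not real data).
In [ ]:
# illustrative: summary statistics for a Poisson-like spike train with 20 Hz mean rate
demo_stats = calculate_statistics(np.random.exponential(0.05, size=500))
for crt_key in sorted(demo_stats.keys()):
    print('{}: {:.3f}'.format(crt_key, demo_stats[crt_key]))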
In [ ]:
class SimulationStatistician(object):
""" A class to simplify the job of calculating spiking statistics from a
simulation with a given set of parameters.
"""
def __init__(self, tmax, dt, **kwargs):
""" Set the parameters of the simulation.
Any extra arguments apart from `n_reps` get passed directly to
`SpikingLearningSimulation`.
"""
# the simulator class needs a target, but it's irrelevant for our purposes
self.tmax = tmax
self.dt = dt
self.target = np.zeros((1, int_r(self.tmax/self.dt)))
self.n_reps = kwargs.pop('n_reps', 1)
def tracker_generator(simulator, i, n):
""" Generate some trackers. """
res = {}
res['student_spike'] = simulation.EventMonitor(simulator.student)
return res
default_args = {
'n_student_per_output': 10,
'plasticity_learning_rate': 0,
'tutor_rule_gain': 0,
'tracker_generator': tracker_generator
}
default_args.update(kwargs)
self.default_args = default_args
def __call__(self, **kwargs):
""" Run the simulation and return the statistics.
If the keyword argument `combine_reps` is `False` (the default), the
function returns a list of dictionaries, one for each repetition of the
simulation. Otherwise, all the spikes get collected in one vector whose
statistics are returned.
Any other keyword arguments are passed to `SpikingLearningSimulation`.
"""
combine_reps = kwargs.pop('combine_reps', False)
self.args = dict(self.default_args)
self.args.update(kwargs)
self.last_simulator = SpikingLearningSimulation(self.target, self.tmax, self.dt,
**self.args)
self.last_res = self.last_simulator.run(self.n_reps)
if not combine_reps:
stats = []
for crt_res in self.last_res:
crt_isi = collect_isi([crt_res['student_spike']], tmax=self.tmax)
# division to convert from ms to s!
stats.append(calculate_statistics(crt_isi/1000.0))
return stats
else:
# division to convert from ms to s!
return calculate_statistics(collect_isi([_['student_spike'] for _ in self.last_res],
tmax=self.tmax)/1000.0)
In [ ]:
def average_stats(stats_list):
""" Take the average over the values of a list of dictionaries.
Values that cannot be averaged are not included in the final result.
"""
keys = stats_list[0].keys()
res = {}
n = float(len(stats_list))
for crt_key in keys:
crt_values = [_[crt_key] for _ in stats_list]
try:
crt_avg = sum(crt_values)/n
res[crt_key] = crt_avg
        except TypeError:
            # skip values that cannot be summed (e.g., strings)
            pass
return res
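A small illustration of `average_stats` with made-up values: entries that cannot be summed, such as strings, are silently dropped.
In [ ]:
# illustrative: 'label' cannot be summed and is skipped; 'cv' is averaged
print(average_stats([{'cv': 1.0, 'label': 'a'}, {'cv': 2.0, 'label': 'b'}]))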
In [ ]:
def draw_summary_stats_violins(epochs, all_stats, all_ages, simstats, title_map, order,
color=[0.831, 0.333, 0.000],
plot_min_age=43, plot_max_age=160, simpos='left'):
""" Draw a violin plot comparing different bird ages and simulation results statistics.
"""
n_plots = len(order)
nx = int(math.ceil(math.sqrt(n_plots)))
ny = int(math.ceil(float(n_plots)/nx))
epoch_stats = []
for crt_epoch in epochs:
epoch_stats.append([crt_stats for (crt_stats, crt_age) in zip(all_stats, all_ages)
if crt_age >= crt_epoch[0] and crt_age < crt_epoch[1]])
for (i, crtplot) in enumerate(order):
data = []
colors = []
names = []
plt.subplot(ny, nx, 1+i)
for (crt_epoch, crt_stats) in zip(epochs, epoch_stats):
values = np.asarray([x[crtplot] for x in crt_stats])
crt_dph = np.mean(crt_epoch)
data.append(values)
colors.append(color)
names.append('{}-{}'.format(*crt_epoch))
# now append the simulated values
simvalues = np.asarray([x[crtplot] for x in simstats])
simcolor = (0.2, 0.5, 1.0)
simtext = 'sim'
if simpos == 'right':
data.append(simvalues)
colors.append(simcolor)
names.append(simtext)
elif simpos == 'left':
data.insert(0, simvalues)
colors.insert(0, simcolor)
names.insert(0, simtext)
else:
raise Exception('Unknown simpos.')
# and draw!
sns.violinplot(data=data, palette=colors, scale='width')
plt.xticks(range(len(names)), names, rotation=45)
plt.ylim(0, plt.ylim()[1])
plt.grid(True)
plt.ylabel(title_map[crtplot])
# plt.gca().title.set_fontsize(10)
plt.gca().yaxis.label.set_fontsize(8)
for _ in plt.gca().get_xticklabels() + plt.gca().get_yticklabels():
_.set_fontsize(8)
In [ ]:
def draw_summary_stats_violins_double(epochs, all_stats, all_ages, simstats1, simstats2,
title_map, order,
color=[0.831, 0.333, 0.000],
plot_min_age=43, plot_max_age=160, simtext=['sim1', 'sim2']):
""" Draw a violin plot comparing different bird ages and simulation results statistics.
"""
n_plots = len(order)
nx = int(math.ceil(math.sqrt(n_plots)))
ny = int(math.ceil(float(n_plots)/nx))
epoch_stats = []
for crt_epoch in epochs:
epoch_stats.append([crt_stats for (crt_stats, crt_age) in zip(all_stats, all_ages)
if crt_age >= crt_epoch[0] and crt_age < crt_epoch[1]])
for (i, crtplot) in enumerate(order):
data = []
colors = []
names = []
plt.subplot(ny, nx, 1+i)
for (crt_epoch, crt_stats) in zip(epochs, epoch_stats):
values = np.asarray([x[crtplot] for x in crt_stats])
crt_dph = np.mean(crt_epoch)
data.append(values)
colors.append(color)
names.append('{}-{}'.format(*crt_epoch))
# now append the simulated values
simvalues1 = np.asarray([x[crtplot] for x in simstats1])
simvalues2 = np.asarray([x[crtplot] for x in simstats2])
simcolor = (0.2, 0.5, 1.0)
data.append(simvalues2)
colors.append(simcolor)
names.append(simtext[1])
data.insert(0, simvalues1)
colors.insert(0, simcolor)
names.insert(0, simtext[0])
# and draw!
sns.violinplot(data=data, palette=colors, scale='width')
plt.xticks(range(len(names)), names, rotation=45)
plt.ylim(0, plt.ylim()[1])
plt.grid(True)
plt.ylabel(title_map[crtplot])
plt.gca().yaxis.label.set_fontsize(8)
for _ in plt.gca().get_xticklabels() + plt.gca().get_yticklabels():
_.set_fontsize(8)
In [ ]:
spike_stats_raw = scipy.io.loadmat("data/allCellsNew.mat",
squeeze_me=True, chars_as_strings=True,
struct_as_record=False)
all_data = spike_stats_raw['allCellsNew']
isi_data = np.asarray([_.spikeI for _ in all_data])
ages = np.asarray([x.dph for x in all_data])
Select data below 200 dph, because the recordings at 200 dph are all from directed singing.
In [ ]:
selected_mask = (ages < 200)
selected_isi_data = isi_data[selected_mask]
selected_ages = ages[selected_mask]
In [ ]:
all_stats = np.asarray([calculate_statistics(crt_isi) for crt_isi in isi_data])
selected_stats = all_stats[selected_mask]
In [ ]:
stats_title_map = {
'burst_mean_length': 'Mean duration of bursts (s)',
'burst_rate': 'Rate at which bursts occur (Hz)',
'cv': 'Coefficient of variation',
'fano_factor': 'Fano factor (s)',
'firing_rate': 'Average firing rate (Hz)',
'firing_rate_bursts': 'Average firing rate during bursts (Hz)',
'isi_mean': 'Mean inter-spike interval (s)',
'isi_std': 'Standard deviation of inter-spike intervals (s)',
'isi_skew': 'Skewness of inter-spike interval distribution'
}
In [ ]:
predictor_names = ['constant'] + stats_title_map.keys()
selected_summary_stats = np.column_stack(
[(x[name] if name != 'constant' else 1) for x in selected_stats]
for name in predictor_names)
lincoeffs, linresiduals, _, _ = np.linalg.lstsq(selected_summary_stats, np.asarray(selected_ages, dtype=float))
lincoeffs_mapped = {name: lincoeffs[i] for (i, name) in enumerate(predictor_names)}
guessed_selected_ages = np.dot(selected_summary_stats, lincoeffs)
print('Size of residuals is {:.2f} dph.'.format(np.sqrt(linresiduals[0]/(len(selected_ages)-1))))
In [ ]:
plt.scatter(selected_ages, guessed_selected_ages, alpha=0.6)
plt.scatter([selected_ages[68]], [guessed_selected_ages[68]], c='r') # juvenile
plt.scatter([selected_ages[122]], [guessed_selected_ages[122]], c='g') # adult
plt.xlabel("Real age (dph)")
plt.ylabel("Age inferred from linear model (dph)")
print "Correlation coefficient: {:.2f}.".format(np.corrcoef(selected_ages, guessed_selected_ages)[0, 1])
Find summary statistics that are strongly correlated with one another.
In [ ]:
sns.heatmap(np.corrcoef(selected_summary_stats.T), square=True, vmin=-1.0, vmax=1.0)
predictor_names_disp = [s.replace('_', ' ') for s in predictor_names]
plt.xticks(0.5 + np.arange(len(predictor_names_disp)), predictor_names_disp, rotation='vertical')
# XXX seaborn seems to flip the labels but not the coordinates!
plt.yticks(len(predictor_names_disp) - 0.5 - np.arange(len(predictor_names_disp)),
predictor_names_disp, rotation='horizontal');
In [ ]:
def pack_params(d, selection):
""" Convert a subset of the values in the dictionary `d` (selected
by the list `selection`) into a vector.
"""
res = []
for crt_name in selection:
# handle multi-component values
i = crt_name.find('::')
if i == -1:
res.append(d[crt_name])
else:
crt_value = d[crt_name]
if not hasattr(crt_value, '__len__') or len(crt_value) != int(crt_name[i+2:]):
raise Exception('Mismatch between name and length of value.')
res.extend(crt_value)
return np.asarray(res)
In [ ]:
def unpack_params(x, selection):
""" Convert the values from `pack_params` back to a dictionary. """
res = {}
k = 0
for crt_name in selection:
i = crt_name.find('::')
if i == -1:
res[crt_name] = x[k]
k += 1
else:
crt_n = int(crt_name[i+2:])
res[crt_name] = tuple(x[k:k+crt_n])
k += crt_n
return res
In [ ]:
def sanitize_keys(d):
""" 'Sanitize' the keys in the dictionary by removing any trailing '::<number>'
parts.
"""
res = {}
for key, value in d.items():
i = key.find('::')
if i < 0:
res[key] = value
else:
res[key[:i]] = value
return res
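The parameter-naming conventions can be checked with a small round trip (illustrative values): `pack_params` flattens the selected entries into a vector, `unpack_params` restores the dictionary, and `sanitize_keys` strips the '::<number>' suffixes.
In [ ]:
demo_selection = ['ts_weights', 'cs_weights_params##0::2']
demo_dict = {'ts_weights': 0.15, 'cs_weights_params##0::2': (-3.5, 0.55)}
demo_x = pack_params(demo_dict, demo_selection)
print(demo_x)                                                # [ 0.15 -3.5   0.55]
print(unpack_params(demo_x, demo_selection))                 # keys keep the '::2' suffix
print(sanitize_keys(unpack_params(demo_x, demo_selection)))  # suffixes removed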
In [ ]:
def get_bounds_trafo(bounds_dict, selection):
""" Generate 'genotype' to 'phenotype' and reverse transformations
for use with CMA-ES.
"""
bounds = pack_params(bounds_dict, selection)
def bounds_trafo(x):
return bounds[:, 0] + 0.5*(bounds[:, 1] - bounds[:, 0])*(1.0 - np.cos(np.pi*x/10.0))
def bounds_rev_trafo(y):
return 10.0/np.pi*np.arccos(1.0 - 2.0*(y - bounds[:, 0])/(bounds[:, 1] - bounds[:, 0]))
return (bounds_trafo, bounds_rev_trafo)
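A quick illustration of the bounds transformation on a single hypothetical parameter: the forward map sends any genotype value into the requested interval, and the reverse map inverts it.
In [ ]:
demo_fwd, demo_rev = get_bounds_trafo({'student_tau_m': (10.0, 25.0)}, ['student_tau_m'])
print(demo_fwd(np.asarray([0.0, 5.0, 10.0])))   # [ 10.   17.5  25. ]
print(demo_rev(demo_fwd(np.asarray([3.0]))))    # [ 3.]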
In [ ]:
# choose statistics to optimize
optim_stats = ['firing_rate', 'cv', 'isi_skew', 'burst_rate', 'burst_mean_length', 'firing_rate_bursts']
In [ ]:
def stats_diff(stats1, stats2, optim_stats):
""" Calculate the difference between two stats. """
ratios = [stats1[key] / stats2[key] for key in optim_stats]
return np.linalg.norm(np.asarray(ratios) - 1.0)/np.sqrt(len(ratios))
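As a worked example with made-up values: a 20% mismatch in one of two selected statistics gives a distance of norm([0.2, 0])/sqrt(2), or about 0.141.
In [ ]:
print(stats_diff({'firing_rate': 12.0, 'cv': 1.0},
                 {'firing_rate': 10.0, 'cv': 1.0},
                 ['firing_rate', 'cv']))   # ~0.141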
In [ ]:
class ObjectiveFunction(object):
""" An objective function that calculates spiking statistics, for use
with CMA-ES.
"""
def __init__(self, *args, **kwargs):
""" Initialize the objective function.
Arguments
---------
target_stats: dictionary
Target values for the statistics.
optim_stats: list
Select the statistics that are optimized.
selection: list
            Parameters that are being optimized.
        defaults: dict
            Default parameter values, used unless overridden by the optimized
            parameters from `selection`.
pack_fct:
unpack_fct:
Functions that pack and unpack the parameters between vector
form and dictionary form. Their signatures are
pack_fct(dict, selection)
unpack_fct(x, selection)
All other arguments are passed directly to `SimulationStatistician`.
"""
self.target_stats = kwargs.pop('target_stats')
self.optim_stats = kwargs.pop('optim_stats')
self.selection = kwargs.pop('selection')
self.defaults = kwargs.pop('defaults')
self.pack_fct = kwargs.pop('pack_fct')
self.unpack_fct = kwargs.pop('unpack_fct')
self.statistician = SimulationStatistician(*args, **kwargs)
def __call__(self, x):
""" Calculate how different the statistics of the simulation given the
parameters `x` are from the target statistics.
Parameters
----------
x
Numeric array describing the parameters for the simulation.
Returns
-------
A single number representing a measure of how far the current simulation is
from the target, in terms of the summary statistics identified by `self.optim_stats`.
Given the current and target values of the statistics, `crt` and `tgt`, the
output is
            `np.linalg.norm(crt/tgt - 1.0)/np.sqrt(len(crt))`
"""
args = dict(self.defaults)
args.update(self.unpack_fct(x, self.selection))
args = sanitize_keys(args)
stats = self.statistician(**args)
if len(stats) > 1:
avg_stats = average_stats(stats)
else:
avg_stats = stats[0]
self.avg_stats = avg_stats
return stats_diff(self.avg_stats, self.target_stats, self.optim_stats)
In [ ]:
def split_args(args, i):
""" Keep the arguments that are common (i.e., have no '##' in their names)
or belong to simulation `i` (i.e., have '##' + str(i) in their names).
"""
crt_args = {}
for key, value in args.items():
# get rid of any length indicators
pound_j = key.find('::')
if pound_j >= 0:
key = key[:pound_j]
pound_i = key.find('##')
if pound_i < 0:
# a common argument
crt_args[key] = value
elif int(key[pound_i+2:]) == i:
# keep only the arguments meant for this simulation
crt_args[key[:pound_i]] = value
return crt_args
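The '##' convention can be illustrated on a small, made-up argument dictionary: common parameters are kept for every simulation, while '##i' parameters go only to simulation `i`.
In [ ]:
demo_args = {'student_tau_m': 20.0, 'cs_weights_fraction##0': 0.4,
             'cs_weights_fraction##1': 0.17, 'cs_weights_params##1::2': (-2.72, 0.76)}
print(split_args(demo_args, 0))   # common args + simulation-0 specific args
print(split_args(demo_args, 1))   # common args + simulation-1 specific args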
In [ ]:
def merge_dicts(d1, d2):
""" Merge the two dictionaries, keeping values from the second
one whenever there is overlap of keys.
"""
d = dict(d1)
d.update(d2)
return d
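A one-line check of `merge_dicts` with made-up values.
In [ ]:
# illustrative: values from the second dictionary win on overlapping keys
print(merge_dicts({'a': 1, 'b': 2}, {'b': 3}))   # {'a': 1, 'b': 3}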
In [ ]:
class MultipleObjectiveFunction(object):
""" An objective function that optimizes the results of several
simulations instead of one.
The simulations share some parameters, while others are separate.
"""
def __init__(self, *args, **kwargs):
""" Initialize the objective function.
Arguments
---------
target_stats: list of dictionaries
Target values for the statistics.
optim_stats: list
Select the statistics that are optimized.
selection: list
Parameters that are being optimized. Parameters that have names
of the form 'param##0' or 'param##2::3' are specific to a given
simulation (index 0 or 2, respectively, for these examples).
Other parameters are common for all simulations. Note that the
indication for vector parameters ('::3' here) comes after the
simulation index.
defaults: dict
Default parameters that are used if they are not replaced by any
of the values from `selection`.
pack_fct:
unpack_fct:
Functions that pack and unpack parameters between vector
form and dictionary form. Their signatures are
pack_fct(dict, selection)
unpack_fct(x, selection)
All other arguments are passed directly to `SimulationStatistician`.
"""
self.target_stats = kwargs.pop('target_stats')
self.optim_stats = kwargs.pop('optim_stats')
self.selection = kwargs.pop('selection')
self.defaults = kwargs.pop('defaults')
self.pack_fct = kwargs.pop('pack_fct')
self.unpack_fct = kwargs.pop('unpack_fct')
self.n = len(self.target_stats)
self.statisticians = [SimulationStatistician(*args, **kwargs) for _ in xrange(self.n)]
def __call__(self, x):
""" Calculate how different the statistics of the simulations given the
parameters `x` are from the target statistics.
Parameters
----------
x
Numeric array describing the parameters for the simulations.
Returns
-------
A single number representing a measure of how far the current simulations are
from their targets, in terms of the summary statistics identified by
`self.optim_stats`.
Given the current and target values of the statistics, `crt[i]` and `tgt[i]`, the
        output is the mean over simulations `i` of
            `np.linalg.norm(crt[i]/tgt[i] - 1.0)/np.sqrt(len(crt[i]))`
"""
args = dict(self.defaults)
args.update(self.unpack_fct(x, self.selection))
args = sanitize_keys(args)
avg_stats = []
for i, crt_statistician in enumerate(self.statisticians):
crt_args = split_args(args, i)
crt_stats = crt_statistician(**crt_args)
if len(crt_stats) > 1:
crt_stats = average_stats(crt_stats)
else:
crt_stats = crt_stats[0]
avg_stats.append(crt_stats)
self.avg_stats = avg_stats
diffs = [stats_diff(crt, tgt, self.optim_stats) for (crt, tgt) in
zip(self.avg_stats, self.target_stats)]
return np.mean(diffs)
In [ ]:
# choose recordings to use for juvenile and adult bird, respectively
idx_juvenile = 68
idx_adult = 122
In [ ]:
# choose ranges for common parameters
optim_range = {
'conductor_rate_during_burst': (400, 800),
'student_vR': (-75.0, -65.0),
'student_v_th': (-55.0, -45.0),
'student_R': (100.0, 400.0),
'student_tau_m': (10.0, 25.0),
'student_tau_ampa': (3.0, 15.0),
'student_tau_nmda': (80.0, 120.0),
'student_g_inh': (0.1, 2.0),
# 'student_i_external': (-0.5, -0.05),
'student_tau_ref': (1.0, 2.0),
'cs_weights_params##0::2': ((-3.62, -3.52), (0.50, 0.58)),
'cs_weights_params##1::2': ((-2.87, -2.71), (0.71, 0.81)),
'cs_weights_fraction##0': (0.06, 0.5),
'cs_weights_fraction##1': (0.03, 0.5),
'ts_weights': (0.10, 0.20)
}
optim_params = ['conductor_rate_during_burst', 'student_v_th', 'student_R', 'student_vR',
'student_tau_m', 'student_tau_ampa',
'student_tau_nmda', 'student_tau_ref', #'student_i_external',
'student_g_inh',
'cs_weights_params##0::2', 'cs_weights_params##1::2',
'cs_weights_fraction##0', 'cs_weights_fraction##1', 'ts_weights']
In [ ]:
# choose a starting point (guess values)
guess_params = {
'conductor_rate_during_burst': 718.0,
'cs_weights_fraction##0': 0.4,
'cs_weights_fraction##1': 0.17,
'cs_weights_params##0::2': (-3.54, 0.54),
'cs_weights_params##1::2': (-2.72, 0.76),
'student_R': 400.0,
'student_g_inh': 1.48,
'student_tau_ampa': 4.8,
'student_tau_m': 24.5,
'student_tau_nmda': 94.4,
'student_tau_ref': 1.9,
'student_vR': -73.7,
'student_v_th': -53.1,
'ts_weights': 0.12
}
In [ ]:
# length and timestep of simulation
tmax = 600.0
dt = 0.2
objective = MultipleObjectiveFunction(tmax, dt,
target_stats=[selected_stats[idx_juvenile], selected_stats[idx_adult]],
optim_stats=optim_stats, selection=optim_params,
defaults=guess_params,
pack_fct=pack_params, unpack_fct=unpack_params,
n_conductor=300, n_student_per_output=100,
relaxation=400.0, relaxation_conductor=25.0,
plasticity_constrain_positive=False)
In [ ]:
bounds_trafos = get_bounds_trafo(optim_range, optim_params)
# run the optimization -- this can take a while
cma_res = cma.fmin(
objective, pack_params(guess_params, optim_params), sigma0=3.0,
options={'tolfun': 1e-3, 'tolfunhist': 1e-4, 'tolx': 1e-4,
'transformation': bounds_trafos,
'maxfevals': 1000, 'verb_disp': 1}
)
In [ ]:
best_params = unpack_params(cma_res[0], optim_params)
In [ ]:
# save parameters to file
best_params_full = merge_dicts(objective.statisticians[0].default_args,
sanitize_keys(merge_dicts(objective.defaults, best_params)))
# get rid of tracker/snapshot generators
best_params_full.pop('tracker_generator', None)
best_params_full.pop('snapshot_generator', None)
# include some extra data
best_params_full['tmax'] = tmax
best_params_full['dt'] = dt
# do the actual saving
imax = 1000
for i in xrange(imax):
# save to the first available slot, name ending in _0, _1, _2, ...
file_name = 'best_params_joint_{}.pkl'.format(i)
if os.path.exists(file_name):
continue
with open(file_name, 'wb') as out:
pickle.dump(best_params_full, out, 2)
break
# copy the resulting file to default_params.pkl to be used by the figure-making code
# below and by the simulation code in spiking_simulations.ipynb
In [ ]:
cma_res[-1].load()
plt.plot(np.min(cma_res[-1].f, axis=1))
plt.xlabel('iteration')
plt.ylabel('error')
In [ ]:
args_juvenile = merge_dicts(objective.statisticians[0].default_args,
split_args(merge_dicts(guess_params, best_params), 0))
args_adult = merge_dicts(objective.statisticians[1].default_args,
split_args(merge_dicts(guess_params, best_params), 1))
In [ ]:
sim_juvenile = SpikingLearningSimulation(objective.statisticians[0].target, tmax, dt, **args_juvenile)
res_juvenile = sim_juvenile.run(10)
stats_juvenile = calculate_statistics(collect_isi([_['student_spike'] for _ in res_juvenile],
tmax=tmax)/1000.0)
sim_adult = SpikingLearningSimulation(objective.statisticians[1].target, tmax, dt, **args_adult)
res_adult = sim_adult.run(10)
stats_adult = calculate_statistics(collect_isi([_['student_spike'] for _ in res_adult],
tmax=tmax)/1000.0)
In [ ]:
show_repetition_pattern([_['student_spike'] for _ in res_juvenile], idx=range(10), ms=2)
plt.xlim(0, tmax);
plt.title('Juvenile bird');
In [ ]:
show_repetition_pattern([_['student_spike'] for _ in res_adult], idx=range(10), ms=2)
plt.xlim(0, tmax);
plt.title('Adult bird');
In [ ]:
# load the default parameters
with open('default_params.pkl', 'rb') as inp:
default_params = pickle.load(inp)
In [ ]:
def tracker_generator(simulator, i, n):
""" Generate some trackers. """
res = {}
res['student_spike'] = simulation.EventMonitor(simulator.student)
return res
In [ ]:
actual_params = dict(default_params)
actual_params['tracker_generator'] = tracker_generator
actual_params.pop('target')
#sim = SpikingLearningSimulation(**actual_params)
#res = sim.run(10)
t0 = time.time()
sim = SimulationStatistician(n_reps=50, **actual_params)
res = sim()
t1 = time.time()
print("Simuation took {:.2f} seconds.".format(t1 - t0))
In [ ]:
epochs = [(i, i+25) for i in xrange(40, 160, 25)]
short_title_map = {
'burst_mean_length': 'Average burst length (s)',
'burst_rate': 'Burst frequency (Hz)',
'cv': 'CV of ISI',
'fano_factor': 'Fano factor (s)', # XXX what's the right normalization to get this to be dimensionless?
'firing_rate': 'Average firing rate (Hz)',
'firing_rate_bursts': 'Firing rate during bursts (Hz)',
'isi_mean': 'Average ISI (s)',
'isi_std': 'Standard deviation of ISI (s)',
'isi_skew': 'Skewness of ISI'
}
In [ ]:
plt.figure(figsize=[6, 3.8])
draw_summary_stats_violins(epochs, selected_stats, selected_ages,
simstats=res, title_map=short_title_map,
order=['firing_rate', 'cv', 'isi_skew',
'burst_rate', 'burst_mean_length',
'firing_rate_bursts'])
plt.tight_layout()
safe_save_fig('figs/spiking_matching_violins', png=False)
In [ ]:
# copy the parameter dict so the juvenile settings are not modified in place
actual_params_adult = dict(actual_params)
actual_params_adult['cs_weights_fraction'] = 0.15157340812668241
actual_params_adult['cs_weights_params'] = (-2.7101150158852856, 0.80989719997502518)
t0 = time.time()
sim_adult = SimulationStatistician(n_reps=50, **actual_params_adult)
res_adult = sim_adult()
t1 = time.time()
print("Simuation took {:.2f} seconds.".format(t1 - t0))
In [ ]:
plt.figure(figsize=[6, 3.8])
draw_summary_stats_violins_double(epochs, selected_stats, selected_ages,
simstats1=res, simstats2=res_adult,
title_map=short_title_map,
order=['firing_rate', 'cv', 'isi_skew',
'burst_rate', 'burst_mean_length',
'firing_rate_bursts'],
simtext=['sim juvenile', 'sim adult'])
plt.tight_layout()
safe_save_fig('figs/spiking_matching_violins_double', png=False)
In [ ]: