Testing tutor-student matching with spiking simulations


In [ ]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')

# render all text with LaTeX (requires a working LaTeX installation)
plt.rc('text', usetex=True)
plt.rc('font', family='serif', serif='cm')

# small font sizes for compact, publication-sized figures
plt.rcParams['figure.titlesize'] = 10
plt.rcParams['axes.labelsize'] = 8
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['axes.labelpad'] = 3.0

# widgets used by the ProgressBar class defined further down
from IPython.display import display, clear_output
from ipywidgets import FloatProgress

# comment out the set_matplotlib_formats call below if not working on a
# retina-display computer
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in newer
# IPython releases -- confirm it exists in the installed version
import IPython
IPython.display.set_matplotlib_formats('retina')

In [ ]:
import numpy as np
import copy
import time
import os
import cPickle as pickle

In [ ]:
import simulation
from basic_defs import *
from helpers import *

Define target motor programs


In [ ]:
tmax = 600.0                  # duration of motor program (ms)
dt = 0.2                      # simulation timestep (ms)

# round instead of truncating: tmax/dt can land just below an integer in
# floating point (e.g. 0.3/0.1 == 2.9999999999999996 would give 2 steps)
nsteps = int(round(tmax/dt))
# build the time axis from nsteps so that len(times) == nsteps is guaranteed;
# np.arange with a float step can be off by one sample
times = np.arange(nsteps)*dt

In [ ]:
# add some noise, but keep things reproducible
# Each of the two rows is a sin**6 wave plus a cos**4 wave whose phases are
# jittered by Gaussian noise. Each row is then convolved with a 200-sample
# Gaussian-shaped kernel; mode='same' keeps the length equal to len(times).
# NOTE: the random draws happen in a fixed order right after the seed, so
# reordering this expression changes the resulting target.
np.random.seed(0)
target_complex = 100.0*np.vstack((
            np.convolve(np.sin(times/100 + 0.1*np.random.randn(len(times)))**6 +
                      np.cos(times/150 + 0.2*np.random.randn(len(times)) + np.random.randn())**4,
                  np.exp(-0.5*np.linspace(-3.0, 3.0, 200)**2)/np.sqrt(2*np.pi)/80, mode='same'),
            np.convolve(np.sin(times/110 + 0.15*np.random.randn(len(times)) + np.pi/3)**6 +
                      np.cos(times/100 + 0.2*np.random.randn(len(times)) + np.random.randn())**4,
                  np.exp(-0.5*np.linspace(-3.0, 3.0, 200)**2)/np.sqrt(2*np.pi)/80, mode='same'),
        ))

In [ ]:
# or start with something simple: constant target
# (one row per motor output, each held at a fixed level for the whole program)
target_const = np.vstack([np.full(len(times), level) for level in (70.0, 50.0)])

In [ ]:
# or something simple but not trivial: steps
# Integer division is needed because len(times)/2 is a float under Python 3.
# The second half gets whatever is left, so each row has exactly len(times)
# samples even when len(times) is odd.
half_len = len(times)//2
rest_len = len(times) - half_len
target_piece = np.vstack((
    np.hstack((20.0*np.ones(half_len), 100.0*np.ones(rest_len))),
    np.hstack((60.0*np.ones(half_len), 30.0*np.ones(rest_len)))))

In [ ]:
# collect the candidate targets by name (selected via target_choice below)
targets = dict(complex=target_complex, piece=target_piece, constant=target_const)

Choose target


In [ ]:
# choose one target
# (one of the keys of `targets`: 'complex', 'piece' or 'constant')
target_choice = 'complex'
#target_choice = 'constant'

In [ ]:
# copy so that the in-place tapering below leaves the array in `targets` intact
target = copy.copy(targets[target_choice])

# make sure the target smoothly goes to zero at the edges
# this is to match the spiking simulation, which needs some time to ramp
# up in the beginning and time to ramp down at the end
edge_duration = 100.0 # ms
edge_len = int(edge_duration/dt)
tapering_x = np.linspace(0.0, 1.0, edge_len, endpoint=False)
# cubic smoothstep 3x^2 - 2x^3: rises from 0 towards 1 with zero slope at x = 0
tapering = (3 - 2*tapering_x)*tapering_x**2
# ramp up over the first edge_len samples, mirror-image ramp down at the end
target[:, :edge_len] *= tapering
target[:, -edge_len:] *= tapering[::-1]

General definitions


In [ ]:
class ProgressBar(object):

    """ A callable that displays a widget progress bar and can also make a plot showing
    the learning trace.

    The instance is called as `progress(i, n)` for repetition `i` out of `n`.
    The learning trace is read from `simulator._current_res`, a list of dicts
    that each have an 'average_error' entry.

    Parameters
    ----------
    simulator
        Object whose `_current_res` attribute holds the results so far.
    show_graph : bool
        Whether to redraw the error-vs-repetition plot.
    graph_step : int
        Redraw the plot every `graph_step` steps (and on the final step).
    max_error : float
        Errors above this value (or non-finite ones) are reported as
        'very large' in the progress text.
    """

    def __init__(self, simulator, show_graph=True, graph_step=20, max_error=1000):
        self.t0 = None
        self.float = None
        self.show_graph = show_graph
        self.graph_step = graph_step
        self.simulator = simulator
        self.max_error = max_error
        self.print_last = True

    def __call__(self, i, n):
        t = time.time()
        if self.t0 is None:
            self.t0 = t
        t_diff = t - self.t0

        current_res = self.simulator._current_res

        text = 'step: {} ; time elapsed: {:.1f}s'.format(i, t_diff)

        if len(current_res) > 0:
            last_error = current_res[-1]['average_error']
            # NaN and inf fail this comparison and end up as 'very large'
            if last_error <= self.max_error:
                text += ' ; last error: {:.2f}'.format(last_error)
            else:
                text += ' ; last error: very large'

        if self.float is None:
            self.float = FloatProgress(min=0, max=100)
            display(self.float)
        # guard against n == 0 so an empty run does not divide by zero
        percentage = min(round(i*100.0/n), 100) if n > 0 else 100
        self.float.value = percentage
        self.float.description = text

        if self.show_graph and (i % self.graph_step == 0 or i == n):
            self._draw_graph(current_res, i, n)

        if i == n:
            self.float.close()
            if self.print_last:
                print(text)

    def _draw_graph(self, current_res, i, n):
        """ Redraw the learning trace (error vs. repetition) in the current figure. """
        crt_res = [_['average_error'] for _ in current_res]
        # start from a clean figure; otherwise every redraw stacks another copy
        # of the whole trace on top of the previous ones
        plt.clf()
        plt.plot(range(len(crt_res)), crt_res, '.-k')
        # avoid identical left/right limits when n == 1
        plt.xlim(0, max(n-1, 1))
        plt.xlabel('repetition')
        plt.ylabel('error')

        # non-finite errors would make the y-axis limits invalid
        finite_res = [_ for _ in crt_res if np.isfinite(_)]
        if len(finite_res) > 0:
            if i < 100:
                plt.ylim(np.min(finite_res) - 0.1, np.max(finite_res) + 0.1)
            else:
                plt.ylim(0, np.max(finite_res))
        else:
            plt.ylim(0, 1)

        clear_output(wait=True)
        if i < n:
            display(plt.gcf())

In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    Every 10th repetition records the tutor and motor outputs and the student
    spikes; all other repetitions record nothing.
    """
    if i % 10 != 0:
        return {}
    return {
        'tutor': simulation.StateMonitor(simulator.tutor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
    }

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 50th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 50:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

Create default parameters file


In [ ]:
# start with the best parameters from the experiment matcher
# NOTE: pickle.load can execute arbitrary code -- only load trusted files
best_params_file = 'best_params_joint.pkl'
with open(best_params_file, 'rb') as inp:
    best_params_full = pickle.load(inp)

# keep the values for the juvenile bird
# Keys may carry a '##<k>' suffix. Entries with k > 0 are skipped; for the
# others the suffix is stripped. Keys without a suffix are copied unchanged.
# NOTE(review): presumably suffix 0 denotes the juvenile stage -- confirm
# against the experiment-matcher output
default_params = {}
for key, value in best_params_full.items():
    pound_i = key.find('##')
    if pound_i >= 0:
        if int(key[pound_i+2:]) > 0:
            # this is not for the juvenile
            continue
        key = key[:pound_i]
    
    default_params[key] = value

# add the target, and make sure we have the right tmax and dt
default_params['target'] = target
default_params['tmax'] = tmax
default_params['dt'] = dt

# the number of student neurons per output doesn't have to be so high
default_params['n_student_per_output'] = 40

# the best_params file also has no learning, so let's set better defaults there
default_params['plasticity_learning_rate'] = 0.6e-9
default_params['plasticity_constrain_positive'] = True
default_params['plasticity_taus'] = (80.0, 40.0)
default_params['plasticity_params'] = (1.0, 0.0)

# drop 'tutor_rule_gain' (if present) and set 'tutor_rule_gain_per_student' instead
default_params.pop('tutor_rule_gain', None)
default_params['tutor_rule_gain_per_student'] = 0.5
default_params['tutor_rule_tau'] = 0.0

# the best_params also didn't care about the controller -- let's set that
default_params['controller_mode'] = 'sum'
default_params['controller_scale'] = 0.5

# save!
# refuses to overwrite an existing file, so re-running this cell after the
# first time raises on purpose
defaults_name = 'default_params.pkl'
if not os.path.exists(defaults_name):
    with open(defaults_name, 'wb') as out:
        pickle.dump(default_params, out, 2)
else:
    raise Exception('File exists!')

Generate data for figures

Learning curve (blackbox)


In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    The motor output and the student spikes are recorded on every repetition.
    Every 10th repetition additionally records the tutor, conductor and
    student outputs, as well as the conductor spikes.
    """
    monitors = {
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
    }
    if i % 10:
        return monitors

    monitors.update([
        ('tutor', simulation.StateMonitor(simulator.tutor, 'out')),
        ('conductor', simulation.StateMonitor(simulator.conductor, 'out')),
        ('student', simulation.StateMonitor(simulator.student, 'out')),
        ('conductor_spike', simulation.EventMonitor(simulator.conductor)),
    ])
    return monitors

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 10th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 10:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# load the default parameters
# (written by the "Create default parameters file" section above)
with open('default_params.pkl', 'rb') as inp:
    default_params = pickle.load(inp)

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(12314)

actual_params = dict(default_params)
actual_params['plasticity_params'] = (1.0, 0.0)
actual_params['tutor_rule_tau'] = 80.0
actual_params['progress_indicator'] = ProgressBar
# uses whichever tracker/snapshot generators were defined last (the cell above)
actual_params['tracker_generator'] = tracker_generator
actual_params['snapshot_generator'] = snapshot_generator_pre
simulator = SpikingLearningSimulation(**actual_params)

In [ ]:
res = simulator.run(200)

In [ ]:
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/spiking_example.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'params': actual_params, 'res': res}, out, 2)
else:
    raise Exception('File exists!')

In [ ]:
# NOTE(review): plots against the global `target`, which should match
# actual_params['target'] as long as target_choice has not been changed since
# default_params.pkl was written
plot_evolution(res, target, dt)

In [ ]:
# raster of the student spikes over the last 10 repetitions
show_repetition_pattern([_['student_spike'] for _ in res[-10:]], idx=range(10), ms=2.0)
plt.xlim(0, tmax)
# average student firing rate (Hz) in the last repetition; spikes at or after
# tmax (ms) are ignored
crt_times0 = np.asarray(res[-1]['student_spike'].t)
crt_times = crt_times0[crt_times0 < tmax]
# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Average firing rate {:.2f} Hz.'.format(len(crt_times)*1000.0/tmax/simulator.student.N))

Learning curve (blackbox), realistic target


In [ ]:
# Hand-made "realistic" targets: every motor channel is a sequence of
# rectangular pulses, given as (start ms, end ms, amplitude), smoothed with a
# Gaussian-shaped kernel. No noise is added here; the seed call is kept only so
# that the global random state matches the original notebook.
np.random.seed(0)

smoothlen = 400

def _smoothed_pulse_target(pulses):
    """ Build one target channel from `(start_ms, end_ms, amplitude)` pulses.

    Each pulse, in the order given, sets the samples from `start_ms/dt` up to
    (but excluding) `end_ms/dt` to `amplitude`, with indices converted via
    `int_r`. The step signal is then convolved (mode 'same') with a
    Gaussian-shaped kernel of `smoothlen` samples.
    """
    channel = np.zeros(len(times))
    for start_ms, end_ms, amplitude in pulses:
        channel[int_r(start_ms/dt):int_r(end_ms/dt)] = amplitude
    kernel = np.exp(-0.5*np.linspace(-3.0, 3.0, smoothlen)**2)/np.sqrt(2*np.pi)/80
    return np.convolve(channel, kernel, mode='same')

realTarget1 = _smoothed_pulse_target([
    (50.0, 65.0, 90.0), (65.0, 75.0, 20.0), (75.0, 100.0, 90.0),
    (125.0, 150.0, 80.0), (150.0, 160.0, 40.0),
    (250.0, 280.0, 80.0), (305.0, 320.0, 70.0),
    (350.0, 360.0, 90.0),
    (410.0, 450.0, 100.0), (450.0, 470.0, 60.0),
    (500.0, 540.0, 80.0),
])

realTarget2 = _smoothed_pulse_target([
    (60.0, 75.0, 90.0), (100.0, 115.0, 100.0),
    (265.0, 290.0, 90.0), (320.0, 330.0, 40.0), (330.0, 365.0, 100.0),
    (385.0, 400.0, 90.0),
    (415.0, 450.0, 80.0), (470.0, 480.0, 80.0),
    (520.0, 540.0, 90.0),
])

realTarget3 = _smoothed_pulse_target([
    (70.0, 100.0, 100.0),
    (160.0, 180.0, 100.0),
    (260.0, 275.0, 100.0),
    (285.0, 310.0, 100.0),
    (340.0, 360.0, 100.0),
    (435.0, 470.0, 90.0),
    (530.0, 540.0, 80.0),
])

realTarget4 = _smoothed_pulse_target([
    (50.0, 65.0, 30.0), (65.0, 85.0, 100.0),
    (135.0, 150.0, 90.0),
    (285.0, 300.0, 90.0),
    (385.0, 405.0, 60.0),
    (430.0, 450.0, 100.0),
    (525.0, 540.0, 70.0),
])

realTarget5 = _smoothed_pulse_target([
    (75.0, 85.0, 20.0), (115.0, 130.0, 60.0),
    (180.0, 200.0, 90.0), (265.0, 290.0, 100.0), (325.0, 350.0, 70.0),
    (410.0, 420.0, 80.0), (440.0, 455.0, 70.0),
    (535.0, 545.0, 20.0),
])

In [ ]:
# one row per motor output channel
realTarget = np.vstack([realTarget1, realTarget2, realTarget3, realTarget4, realTarget5])

In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    The motor output and the student spikes are recorded on every repetition.
    Every 10th repetition additionally records the tutor, conductor and
    student outputs, as well as the conductor spikes.
    """
    monitors = {
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
    }
    if i % 10:
        return monitors

    monitors.update([
        ('tutor', simulation.StateMonitor(simulator.tutor, 'out')),
        ('conductor', simulation.StateMonitor(simulator.conductor, 'out')),
        ('student', simulation.StateMonitor(simulator.student, 'out')),
        ('conductor_spike', simulation.EventMonitor(simulator.conductor)),
    ])
    return monitors

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 10th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 10:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# load the default parameters
# (written by the "Create default parameters file" section above)
with open('default_params.pkl', 'rb') as inp:
    default_params = pickle.load(inp)

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(12314)

actual_params = dict(default_params)

# replace the default target by the five-channel hand-made target
actual_params['target'] = realTarget

actual_params['plasticity_params'] = (1.0, 0.0)
actual_params['tutor_rule_tau'] = 80.0
actual_params['progress_indicator'] = ProgressBar
# uses whichever tracker/snapshot generators were defined last (the cell above)
actual_params['tracker_generator'] = tracker_generator
actual_params['snapshot_generator'] = snapshot_generator_pre

# larger tutor gain and learning rate than the values written to
# default_params.pkl above (0.5 and 0.6e-9)
actual_params['tutor_rule_gain_per_student'] = 1.0
actual_params['plasticity_learning_rate'] = 1e-9

# alternative settings, currently disabled
#actual_params['n_student_per_output'] = 10
#actual_params['controller_scale'] = 0.5*4

simulator = SpikingLearningSimulation(**actual_params)

In [ ]:
res = simulator.run(600)

In [ ]:
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/spiking_example_realistic.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'params': actual_params, 'res': res}, out, 2)
else:
    raise Exception('File exists!')

In [ ]:
plot_evolution(res, realTarget, dt)

Learning curve (blackbox), constant inhibition


In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    The motor output and the student spikes are recorded on every repetition.
    Every 10th repetition additionally records the tutor, conductor and
    student outputs, as well as the conductor spikes.
    """
    monitors = {
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
    }
    if i % 10:
        return monitors

    monitors.update([
        ('tutor', simulation.StateMonitor(simulator.tutor, 'out')),
        ('conductor', simulation.StateMonitor(simulator.conductor, 'out')),
        ('student', simulation.StateMonitor(simulator.student, 'out')),
        ('conductor_spike', simulation.EventMonitor(simulator.conductor)),
    ])
    return monitors

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 10th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 10:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# load the default parameters
# (written by the "Create default parameters file" section above)
with open('default_params.pkl', 'rb') as inp:
    default_params = pickle.load(inp)

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(12314)

actual_params = dict(default_params)
actual_params['plasticity_params'] = (1.0, 0.0)
actual_params['tutor_rule_tau'] = 80.0
actual_params['progress_indicator'] = ProgressBar
# uses whichever tracker/snapshot generators were defined last (the cell above)
actual_params['tracker_generator'] = tracker_generator
actual_params['snapshot_generator'] = snapshot_generator_pre
# "constant inhibition": the inhibitory conductance is set to zero and a
# constant negative external current is applied to the students instead
actual_params['student_g_inh'] = 0
actual_params['student_i_external'] = -0.23
simulator = SpikingLearningSimulation(**actual_params)

In [ ]:
res = simulator.run(200)

In [ ]:
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/spiking_example_const_inh.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'params': actual_params, 'res': res}, out, 2)
else:
    raise Exception('File exists!')

In [ ]:
# NOTE(review): plots against the global `target`, which should match
# actual_params['target'] as long as target_choice has not been changed
plot_evolution(res, target, dt)

Reinforcement example (0 ms)


In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    Every 10th repetition records the tutor and motor outputs and the student
    spikes; all other repetitions record nothing.
    """
    if i % 10 != 0:
        return {}
    return {
        'tutor': simulation.StateMonitor(simulator.tutor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
    }

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 50th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 50:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# load the default parameters
# (written by the "Create default parameters file" section above)
with open('default_params.pkl', 'rb') as inp:
    default_params = pickle.load(inp)

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(212312)

args = dict(default_params)

args['relaxation'] = 200.0
args['relaxation_conductor'] = 200.0

# switch the tutor to the 'reinforcement' rule with tutor_rule_tau = 0 ms
args['tutor_tau_out'] = 40.0
args['tutor_rule_type'] = 'reinforcement'
args['tutor_rule_learning_rate'] = 0.004
args['tutor_rule_compress_rates'] = True
args['tutor_rule_relaxation'] = None
args['tutor_rule_tau'] = 0.0

args['plasticity_params'] = (1.0, 0.0)
args['plasticity_constrain_positive'] = True
args['plasticity_learning_rate'] = 7e-10


# the generators and the progress indicator are only added to a copy, so the
# `args` dict that is pickled below does not contain them
args_actual = dict(args)
args_actual['tracker_generator'] = tracker_generator
args_actual['snapshot_generator'] = snapshot_generator_pre
args_actual['progress_indicator'] = ProgressBar

simulator = SpikingLearningSimulation(**args_actual)

In [ ]:
res = simulator.run(10000)

In [ ]:
# save!
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/reinforcement_example_0ms.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'res': res, 'args': args}, out, 2)
else:
    raise Exception('File exists!')

Reinforcement example (80 ms)


In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    Every 10th repetition records the tutor and motor outputs and the student
    spikes; all other repetitions record nothing.
    """
    if i % 10 != 0:
        return {}
    return {
        'tutor': simulation.StateMonitor(simulator.tutor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
    }

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 50th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 50:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(212312)

# unlike the other runs, this one spells out every parameter explicitly instead
# of starting from default_params (it does use the global target, tmax and dt)
args = dict(
        target=target, tmax=tmax, dt=dt,
        n_conductor=300, n_student_per_output=40,
        relaxation=200.0, relaxation_conductor=200.0, # XXX different from blackbox!
        conductor_rate_during_burst=769.7,
        controller_mode='sum',
        controller_scale=0.5,
        tutor_tau_out=40.0,
        tutor_rule_type='reinforcement',
        tutor_rule_learning_rate=0.006,
        tutor_rule_compress_rates=True,
        tutor_rule_tau=80.0,
        tutor_rule_relaxation=None,                   # XXX different from blackbox!
        cs_weights_fraction=0.488, ts_weights=0.100,
        plasticity_constrain_positive=True,
        plasticity_learning_rate=6e-10,
        plasticity_taus=(80.0, 40.0),
        plasticity_params=(1.0, 0.0),
        student_R=383.4, student_g_inh=1.406,
        student_tau_ampa=5.390, student_tau_nmda=81.92,
        student_tau_m=20.31, student_tau_ref=1.703,
        student_vR=-74.39, student_v_th=-45.47
    )

# the generators and the progress indicator are only added to a copy, so the
# `args` dict that is pickled below does not contain them
args_actual = dict(args)
args_actual['tracker_generator'] = tracker_generator
args_actual['snapshot_generator'] = snapshot_generator_pre
args_actual['progress_indicator'] = ProgressBar

simulator = SpikingLearningSimulation(**args_actual)

In [ ]:
res = simulator.run(16000)

In [ ]:
# save!
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/reinforcement_example_80ms.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'res': res, 'args': args}, out, 2)
else:
    raise Exception('File exists!')

In [ ]:
plot_evolution(res, target, dt)

Reinforcement example (alpha=10, beta=9, tau=440 ms)


In [ ]:
def tracker_generator(simulator, i, n):
    """ Build the monitors to attach for repetition `i` (out of `n`).

    Every 10th repetition records the tutor and motor outputs and the student
    spikes; all other repetitions record nothing.
    """
    if i % 10 != 0:
        return {}
    return {
        'tutor': simulation.StateMonitor(simulator.tutor, 'out'),
        'student_spike': simulation.EventMonitor(simulator.student),
        'motor': simulation.StateMonitor(simulator.motor, 'out'),
    }

def snapshot_generator_pre(simulator, i, n):
    """ Snapshot taken before repetition `i` (out of `n`).

    Every 50th repetition stores a copy of the conductor->student weight
    matrix under 'weights'; otherwise nothing is stored.
    """
    snapshots = {}
    if not i % 50:
        snapshots['weights'] = np.copy(simulator.conductor_synapses.W)
    return snapshots

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(12234)

# NOTE: unlike the other sections, default_params is not reloaded here; this
# relies on it having been loaded (or created) by an earlier cell
args = dict(default_params)

args['relaxation'] = 200.0
args['relaxation_conductor'] = 200.0

args['tutor_tau_out'] = 40.0
args['tutor_rule_type'] = 'reinforcement'
args['tutor_rule_learning_rate'] = 0.004
args['tutor_rule_compress_rates'] = True
args['tutor_rule_relaxation'] = None
args['tutor_rule_tau'] = 440.0

# (alpha, beta) = (10, 9), as in the section title
args['plasticity_params'] = (10.0, 9.0)
args['plasticity_constrain_positive'] = True
args['plasticity_learning_rate'] = 7e-10

# the generators and the progress indicator are only added to a copy, so the
# `args` dict that is pickled below does not contain them
args_actual = dict(args)
args_actual['tracker_generator'] = tracker_generator
args_actual['snapshot_generator'] = snapshot_generator_pre
args_actual['progress_indicator'] = ProgressBar

simulator = SpikingLearningSimulation(**args_actual)

In [ ]:
res = simulator.run(10000)

In [ ]:
# save!
# refuse to overwrite an earlier run; the 'save' directory must already exist
file_name = 'save/reinforcement_example_a10b9_440ms.pkl'
if not os.path.exists(file_name):
    with open(file_name, 'wb') as out:
        pickle.dump({'res': res, 'args': args}, out, 2)
else:
    raise Exception('File exists!')

In [ ]:
plot_evolution(res, target, dt)

Make figures

Tutor-student mismatch heatmap and convergence map -- blackbox spiking

The data for this section must be generated by running the summarize.py script on the results of the run_tscale_batch.py script, which is designed to run on a cluster.


In [ ]:
# summary produced by summarize.py from the run_tscale_batch.py output;
# the cells below use its 'res_array' and 'args_array' entries
file_name = 'spike_out/songspike_tscale_batch_8.8.160525.1530_summary.pkl'
with open(file_name, 'rb') as inp:
    mismatch_data = pickle.load(inp)

In [ ]:
# NOTE(review): sim_idx / max_steps presumably select the repetition at which
# the errors are read -- confirm in helpers
make_heatmap_plot(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'],
                  vmin=1.0, vmax=10, sim_idx=250)
safe_save_fig('figs/spiking_mismatch_heatmap_sum_log_8', png=False)

In [ ]:
make_convergence_map(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'],
                     max_steps=250)
safe_save_fig('figs/spiking_mismatch_convmap_sum_log_8', png=False)

Tutor-student bigger mismatch heatmap and convergence map -- blackbox spiking

The data for this section must be generated by running the summarize.py script on the results of the run_tscale_batch.py script, which is designed to run on a cluster.


In [ ]:
# summary produced by summarize.py from the run_tscale_batch.py output
file_name = 'spike_out/songspike_tscale_batch_12.12.161122.1802_summary.pkl'
with open(file_name, 'rb') as inp:
    mismatch_data = pickle.load(inp)

In [ ]:
make_heatmap_plot(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'],
                  vmin=0.5, vmax=10, sim_idx=999)
safe_save_fig('figs/spiking_mismatch_heatmap_sum_log_12', png=False)

In [ ]:
make_convergence_map(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'],
                     max_steps=999)
safe_save_fig('figs/spiking_mismatch_convmap_sum_log_12', png=False)

In [ ]:
# last error of every run in the grid; non-finite entries (e.g. NaN) become +inf
error_matrix = np.asarray([[_[-1] for _ in crt_res] for crt_res in mismatch_data['res_array']])
error_matrix[~np.isfinite(error_matrix)] = np.inf
# tutor_rule_tau of the runs in the first row of the grid
tau_levels = np.asarray([_['tutor_rule_tau'] for _ in mismatch_data['args_array'][0]])

In [ ]:
# final error along the diagonal of the grid (row index == column index)
# NOTE(review): presumably the runs where tutor and student timescales match
plt.semilogx(tau_levels, np.diag(error_matrix), '.-k')

Tutor-student mismatch heatmap and convergence map -- reinforcement


In [ ]:
# summary of the reinforcement-rule batch run
file_name = 'spike_out/song_reinf_tscale_batch_8.8.160607.1153_summary.pkl'
with open(file_name, 'rb') as inp:
    mismatch_data = pickle.load(inp)

In [ ]:
make_heatmap_plot(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'],
                  vmin=1.0, vmax=10)
safe_save_fig('figs/reinforcement_mismatch_heatmap_sum_log_8', png=False)

In [ ]:
make_convergence_map(mismatch_data['res_array'], args_matrix=mismatch_data['args_array'], max_error=12)
safe_save_fig('figs/reinforcement_mismatch_convmap_sum_log_8', png=False)

Spiking example learning curve and raster plots


In [ ]:
# load the run saved in the "Learning curve (blackbox)" section
with open('save/spiking_example.pkl', 'rb') as inp:
    spiking_example_data = pickle.load(inp)

In [ ]:
# learning curve with an inset; all run parameters come from the saved file
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(spiking_example_data['res'], plt.gca(), target_lw=2,
                        inset=True, inset_pos=[0.4, 0.4, 0.4, 0.4],
                        alpha=spiking_example_data['params']['plasticity_params'][0],
                        beta=spiking_example_data['params']['plasticity_params'][1],
                        tau_tutor=spiking_example_data['params']['tutor_rule_tau'],
                        target=spiking_example_data['params']['target'])

axs[0].set_ylim(0, 15);


safe_save_fig('figs/spiking_example_learning_curve', png=False)

In [ ]:
# student spike raster for the first 5 repetitions (early in learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][:5]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 23, 65, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/spiking_simraster_juvenile', png=False)

In [ ]:
# student spike raster for the last 5 repetitions (end of learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][-5:]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 23, 65, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/spiking_simraster_adult', png=False)

Spiking example learning curve and raster plots, realistic target


In [ ]:
# load the run saved in the "realistic target" section
with open('save/spiking_example_realistic.pkl', 'rb') as inp:
    spiking_example_data = pickle.load(inp)

In [ ]:
# learning curve with an inset; all run parameters come from the saved file
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(spiking_example_data['res'], plt.gca(), target_lw=2,
                        inset=True, inset_pos=[0.4, 0.45, 0.4, 0.4],
                        legend_pos=(0.7, 1.1),
                        alpha=spiking_example_data['params']['plasticity_params'][0],
                        beta=spiking_example_data['params']['plasticity_params'][1],
                        tau_tutor=spiking_example_data['params']['tutor_rule_tau'],
                        target=spiking_example_data['params']['target'])

axs[0].set_ylim(0, 15);


safe_save_fig('figs/spiking_example_realistic_learning_curve', png=False)

In [ ]:
# student spike raster for the first 5 repetitions (early in learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][:5]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 87, 123, 165])
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/spiking_simraster_realistic_juvenile', png=False)

In [ ]:
# student spike raster for the last 5 repetitions (end of learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][-5:]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 87, 123, 165])
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/spiking_simraster_realistic_adult', png=False)

In [ ]:
# movie of the motor output over the 600 saved repetitions of the realistic run
make_convergence_movie('figs/spiking_convergence_movie_small_tau.mov',
                       spiking_example_data['res'], spiking_example_data['params']['target'],
                       idxs=range(0, 600), length=12.0,
                       ymax=80.0)

Spiking example, constant inhibition, learning curve and raster plots


In [ ]:
# load the run saved in the "constant inhibition" section
with open('save/spiking_example_const_inh.pkl', 'rb') as inp:
    spiking_example_data = pickle.load(inp)

In [ ]:
# learning curve with an inset; all run parameters come from the saved file
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(spiking_example_data['res'], plt.gca(), target_lw=2,
                        inset=True, inset_pos=[0.4, 0.4, 0.4, 0.4],
                        alpha=spiking_example_data['params']['plasticity_params'][0],
                        beta=spiking_example_data['params']['plasticity_params'][1],
                        tau_tutor=spiking_example_data['params']['tutor_rule_tau'],
                        target=spiking_example_data['params']['target'])

axs[0].set_ylim(0, 15);


safe_save_fig('figs/spiking_example_const_inh_learning_curve', png=False)

In [ ]:
# student spike raster for the first 5 repetitions (early in learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][:5]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 23, 65, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

In [ ]:
# student spike raster for the last 5 repetitions (end of learning)
plt.figure(figsize=(3, 1))
crt_res = spiking_example_data['res'][-5:]
crt_tmax = spiking_example_data['params']['tmax']
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 47, 23, 65, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
# use the saved run's tmax rather than the global one
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

In [ ]:
# movie of the motor output over the 200 saved repetitions of the constant-inhibition run
make_convergence_movie('figs/spiking_convergence_movie_const_inh.mov',
                       spiking_example_data['res'], spiking_example_data['params']['target'],
                       idxs=range(0, 200), length=4.0,
                       ymax=80.0)

Reinforcement example learning curves

Reinforcement learning curve, small tau


In [ ]:
# load the run saved in the "Reinforcement example (0 ms)" section
with open('save/reinforcement_example_0ms.pkl', 'rb') as inp:
    reinf_shorttau = pickle.load(inp)

In [ ]:
# weight matrix snapshot at repetition 7500 (that run stored weights every 50
# repetitions, so 7500 has one)
plt.imshow(reinf_shorttau['res'][7500]['weights'], aspect='auto', interpolation='nearest',
           cmap='Blues', vmin=0, vmax=0.3)
plt.colorbar()

In [ ]:
plot_evolution(reinf_shorttau['res'],
               reinf_shorttau['args']['target'],
               reinf_shorttau['args']['dt'])

In [ ]:
# the last 9 repetitions are dropped so the final entry is index 9990, a
# multiple of 10 (that run's tracker only recorded motor output every 10 steps)
# NOTE(review): confirm draw_convergence_plot relies on the last entry's monitors
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(reinf_shorttau['res'][:-9], plt.gca(), target_lw=2,
                      inset=True,
                      alpha=reinf_shorttau['args']['plasticity_params'][0],
                      beta=reinf_shorttau['args']['plasticity_params'][1],
                      tau_tutor=reinf_shorttau['args']['tutor_rule_tau'],
                      target=reinf_shorttau['args']['target'],
                      inset_pos=[0.52, 0.45, 0.4, 0.4])

axs[0].set_xticks(range(0, 8001, 2000))
axs[0].set_ylim(0, 15);

axs[1].set_yticks(range(0, 81, 20));

safe_save_fig('figs/reinforcement_convergence_plot_small_tau', png=False)

In [ ]:
# student spike raster for repetitions 0, 10, ..., 40 (spikes were only
# recorded every 10th repetition in this run)
plt.figure(figsize=(3, 1))
crt_res = reinf_shorttau['res'][:50:10]
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 45, 75, 65, 57])
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
crt_tmax = reinf_shorttau['args']['tmax']
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/reinforcement_simraster_juvenile', png=False)

In [ ]:
# student spike raster for every 10th repetition among the last 50 (spikes
# were only recorded every 10th repetition in this run)
plt.figure(figsize=(3, 1))
crt_res = reinf_shorttau['res'][-50::10]
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[1, 45, 75, 65, 57])
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
crt_tmax = reinf_shorttau['args']['tmax']
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

safe_save_fig('figs/reinforcement_simraster_adult', png=False)

In [ ]:
# NOTE(review): the 0 ms run only recorded 'motor' on every 10th repetition,
# but idxs asks for all 10000 -- confirm make_convergence_movie subsamples or
# skips repetitions that have no motor monitor
make_convergence_movie('figs/reinforcement_convergence_movie_small_tau.mov',
                       reinf_shorttau['res'], reinf_shorttau['args']['target'],
                       idxs=range(0, 10000), length=10.0,
                       ymax=80.0)

Reinforcement learning curve, long tau


In [ ]:
# load the run saved in the "alpha=10, beta=9, tau=440 ms" section
with open('save/reinforcement_example_a10b9_440ms.pkl', 'rb') as inp:
    reinf_longtau = pickle.load(inp)

In [ ]:
# the last 9 repetitions are dropped so the final entry is index 9990, a
# multiple of 10 (that run's tracker only recorded motor output every 10 steps)
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(reinf_longtau['res'][:-9], plt.gca(), target_lw=2,
                      inset=True,
                      alpha=reinf_longtau['args']['plasticity_params'][0],
                      beta=reinf_longtau['args']['plasticity_params'][1],
                      tau_tutor=reinf_longtau['args']['tutor_rule_tau'],
                      target=reinf_longtau['args']['target'],
                      inset_pos=[0.5, 0.45, 0.4, 0.4])

axs[0].set_xticks(range(0, 8001, 2000))
axs[0].set_ylim(0, 15);

axs[1].set_yticks(range(0, 81, 20));

safe_save_fig('figs/reinforcement_convergence_plot_large_tau', png=False)

In [ ]:
# student spike raster for repetitions 0, 10, ..., 40 (spikes were only
# recorded every 10th repetition in this run)
plt.figure(figsize=(3, 1))
crt_res = reinf_longtau['res'][:50:10]
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[3, 48, 19, 62, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
crt_tmax = reinf_longtau['args']['tmax']
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

In [ ]:
# student spike raster for every 10th repetition among the last 50 (spikes
# were only recorded every 10th repetition in this run)
plt.figure(figsize=(3, 1))
crt_res = reinf_longtau['res'][-50::10]
show_repetition_pattern([_['student_spike'] for _ in crt_res], idx=[3, 48, 19, 62, 78], ms=1.0)
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')
plt.gca().spines['left'].set_color('none')
plt.yticks([])
plt.ylabel('')
crt_tmax = reinf_longtau['args']['tmax']
plt.xlim(0, crt_tmax);

# format first, then print: `print('...').format(...)` only works as a Python 2
# print statement and raises AttributeError under Python 3
print('Firing rate {:.2f} Hz.'.format(get_firing_rate(crt_res[-1], crt_tmax)))

In [ ]:
# NOTE(review): the 440 ms run only recorded 'motor' on every 10th repetition,
# but idxs asks for all 10000 -- confirm make_convergence_movie subsamples or
# skips repetitions that have no motor monitor
make_convergence_movie('figs/reinforcement_convergence_movie_large_tau.mov',
                       reinf_longtau['res'], reinf_longtau['args']['target'],
                       idxs=range(0, 10000), length=10.0,
                       ymax=80.0)

Reinforcement learning, evolution of synapse sparsity


In [ ]:
# reload the 0 ms reinforcement run so this section can be run on its own
with open('save/reinforcement_example_0ms.pkl', 'rb') as inp:
    reinf_shorttau = pickle.load(inp)

In [ ]:
# repetitions that carry a weight snapshot (the snapshot generator used for the
# 0 ms reinforcement run only stored weights every 50 repetitions)
motor_idxs = range(0, len(reinf_shorttau['res']), 50)
# total number of student neurons: per-output count times number of outputs
# (rows of the target)
n_students = (reinf_shorttau['args']['n_student_per_output']*
              len(reinf_shorttau['args']['target']))
# average number of weights above 0.01 per student neuron; divide by a float
# because under Python 2 integer/integer division would truncate the average
# NOTE(review): assumes the weight matrix has one entry per conductor-student pair
weight_sparsity = [np.sum(reinf_shorttau['res'][_]['weights'] > 0.01)/float(n_students)
                   for _ in motor_idxs]

In [ ]:
# average number of strong (> 0.01) conductor inputs per student neuron,
# as a function of the repetition number
plt.figure(figsize=(3, 2))
plt.plot(motor_idxs, weight_sparsity, color=[0.200, 0.357, 0.400])
plt.xlabel('repetitions')
plt.ylabel('HVC inputs per RA neuron')
plt.ylim(0, 200);
plt.grid(True)
safe_save_fig('figs/inputs_per_ra_evolution_reinf')

In [ ]: