Testing tutor-student matching with rate-based simulations


In [ ]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import seaborn as sns
# plain white seaborn background for all figures
sns.set_style('white')

# render all figure text through LaTeX, using a Computer Modern serif font
# (matplotlib's usetex mode needs a working LaTeX installation)
plt.rc('text', usetex=True)
plt.rc('font', family='serif', serif='cm')

# small font sizes (in points) and tight label padding for compact figures
plt.rcParams['figure.titlesize'] = 10
plt.rcParams['axes.labelsize'] = 8
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['axes.labelpad'] = 3.0

from IPython.display import display, clear_output
from ipywidgets import FloatProgress

# comment out the next line if not working on a retina-display computer
import IPython
IPython.display.set_matplotlib_formats('retina')

In [ ]:
import numpy as np
import copy
import time
import os
import cPickle as pickle

In [ ]:
import simulation
from basic_defs import *
from helpers import *

Define target motor programs


In [ ]:
# duration of the motor program and the simulation timestep, both in ms
tmax, dt = 600.0, 1.0

# sample times (ms) on a uniform grid covering [0, tmax), and how many there are
times = np.arange(0.0, tmax, dt)
nsteps = int(tmax/dt)

In [ ]:
# add some noise, but keep things reproducible
np.random.seed(0)
# Two output channels (rows). Each row is sin(...)**6 + cos(...)**4, where the sine
# phase gets per-sample Gaussian jitter and the cosine gets both per-sample jitter
# and a random overall phase offset. The sum is smoothed by convolving with a
# 200-sample Gaussian-shaped kernel (mode='same' keeps the length equal to
# len(times)) and finally scaled by 100.
# NOTE: the kernel is not normalized to unit sum (its entries add up to ~0.41).
# NOTE: the order of the np.random calls fixes the realized noise; reordering them
# changes the target even with the same seed.
target_complex = 100.0*np.vstack((
            np.convolve(np.sin(times/100 + 0.1*np.random.randn(len(times)))**6 +
                      np.cos(times/150 + 0.2*np.random.randn(len(times)) + np.random.randn())**4,
                  np.exp(-0.5*np.linspace(-3.0, 3.0, 200)**2)/np.sqrt(2*np.pi)/80, mode='same'),
            np.convolve(np.sin(times/110 + 0.15*np.random.randn(len(times)) + np.pi/3)**6 +
                      np.cos(times/100 + 0.2*np.random.randn(len(times)) + np.random.randn())**4,
                  np.exp(-0.5*np.linspace(-3.0, 3.0, 200)**2)/np.sqrt(2*np.pi)/80, mode='same'),
        ))

In [ ]:
# or start with something simple: a constant level for each of the two outputs
target_const = np.array([np.full(len(times), 70.0), np.full(len(times), 50.0)])

In [ ]:
# or something simple but not trivial: each output steps to a new level halfway through
# (floor division keeps this valid under Python 3, and giving the remainder to the
# second half keeps the total length equal to len(times) even when it is odd)
n_first_half = len(times)//2
n_second_half = len(times) - n_first_half
target_piece = np.vstack((
    np.hstack((20.0*np.ones(n_first_half), 100.0*np.ones(n_second_half))),
    np.hstack((60.0*np.ones(n_first_half), 30.0*np.ones(n_second_half)))))

In [ ]:
# look-up table from target name to motor program (rows = outputs, columns = time steps)
targets = dict(complex=target_complex, piece=target_piece, constant=target_const)

Choose target


In [ ]:
# choose one target: any key of `targets` ('complex', 'piece', or 'constant')
target_choice = 'complex'

In [ ]:
# work on a copy: the in-place tapering below must not modify the entry in `targets`
target = copy.copy(targets[target_choice])

# make sure the target smoothly goes to zero at the edges
# this is to match the spiking simulation, which needs some time to ramp
# up in the beginning and time to ramp down at the end
edge_duration = 100.0 # ms
edge_len = int(edge_duration/dt)
# smoothstep ramp (3 - 2x)*x**2, rising from 0 toward 1 over `edge_len` samples
tapering_x = np.linspace(0.0, 1.0, edge_len, endpoint=False)
tapering = (3 - 2*tapering_x)*tapering_x**2
# with edge_len == 0 the slice `-edge_len:` would select the *whole* row rather than
# nothing, and multiplying it by an empty ramp would fail -- so skip tapering entirely
# NOTE(review): targets shorter than 2*edge_len get overlapping (double) tapering
if edge_len > 0:
    target[:, :edge_len] *= tapering
    target[:, -edge_len:] *= tapering[::-1]

General definitions

Here we define some classes and functions that will be used to run all the simulations we are interested in.


In [ ]:
class ProgressBar(object):
    
    """ A callable that displays a widget progress bar and can also make a plot showing
    the learning trace.

    It is called as `progress(i, n)`, where `i` is the index of the current step and `n`
    the total number of steps; the call with `i == n` closes the bar. Errors are read from
    `simulator._current_res`, a list of dicts each holding an 'average_error' entry.
    The same object can be reused for several runs.
    """
    
    def __init__(self, simulator, show_graph=True, graph_step=50, max_error=1000):
        """ Set up the progress bar.

        Arguments
        ---------
          simulator:
              Object whose `_current_res` list holds the results gathered so far.
          show_graph: bool
              Whether to plot the error trace.
          graph_step: int
              The error trace is re-plotted every `graph_step` steps.
          max_error: float
              Errors larger than this are reported as 'very large' in the text.
        """
        self.t0 = None      # wall-clock time at the start of the current run
        self.float = None   # the `FloatProgress` widget, created on first use in a run
        self.show_graph = show_graph
        self.graph_step = graph_step
        self.simulator = simulator
        self.max_error = max_error
        self.print_last = True
    
    def __call__(self, i, n):
        """ Update the bar (and optionally the error plot) for step `i` out of `n`. """
        t = time.time()
        if i == 0 and self.float is not None:
            # left over from an interrupted run: start over with a fresh widget
            self.float.close()
            self.float = None
        if i == 0 or self.t0 is None:
            # a new run is starting: restart the clock
            self.t0 = t
        t_diff = t - self.t0
        
        current_res = self.simulator._current_res
        
        text = 'step: {} ; time elapsed: {:.1f}s'.format(i, t_diff)
        
        if len(current_res) > 0:
            last_error = current_res[-1]['average_error']
            if last_error <= self.max_error:
                text += ' ; last error: {:.2f}'.format(last_error)
            else:
                text += ' ; last error: very large'
        
        if self.float is None:
            self.float = FloatProgress(min=0, max=100)
            display(self.float)
        # update the value on every call, including the one that creates the widget
        if n > 0:
            self.float.value = min(round(i*100.0/n), 100)
        else:
            self.float.value = 100
            
        self.float.description = text
            
        if self.show_graph and i % self.graph_step == 0:
            crt_res = [_['average_error'] for _ in current_res]
            # start from clean axes so that repeated calls don't stack copies of the trace
            plt.cla()
            plt.plot(range(len(crt_res)), crt_res, '.-k')
            plt.xlim(0, n-1)
            plt.xlabel('repetition')
            plt.ylabel('error')

            # a diverging run can produce inf/nan errors, which matplotlib rejects as
            # axis limits; base the limits on the finite values only
            finite_res = [_ for _ in crt_res if np.isfinite(_)]
            if len(finite_res) > 0:
                if i < 100:
                    plt.ylim(np.min(finite_res) - 0.1, np.max(finite_res) + 0.1)
                else:
                    plt.ylim(0, np.max(finite_res))
            else:
                plt.ylim(0, 1)

            clear_output(wait=True)
            if i < n:
                display(plt.gcf())
        
        if i == n:
            self.float.close()
            # forget the widget and start time so the object can track another run
            self.float = None
            self.t0 = None
            if self.print_last:
                print(text)

In [ ]:
# this defines the basic class used for running the rate-based simulations

class RateLearningSimulation(object):
    
    """ A class that runs the rate-based simulation for several learning cycles. """
    
    def __init__(self, target, tmax, dt, n_conductor, n_student_per_output,
                 relaxation=400.0, relaxation_conductor=25.0,
                 tracker_generator=None, snapshot_generator=None,
                 conductor_burst_length=None,
                 conductor_from_table=None,
                 controller_mode='sum', controller_tau=25.0,
                 controller_mismatch_type='random', controller_mismatch_amount=0,
                 controller_error_map_function=None,
                 controller_mismatch_subdivide_by=1,
                 controller_nonlinearity=None,
                 tutor_rule_type='blackbox', tutor_rule_tau=0.0,
                 tutor_rule_gain=None, tutor_rule_gain_per_student=0.5,
                 tutor_rule_compress_rates=False,
                 tutor_rule_min_rate=0.0, tutor_rule_max_rate=160.0,
                 cs_weights_type='lognormal', cs_weights_params=(-3.57, 0.54),
                 cs_weights_scale=1.0, ts_weights=0.02,
                 plasticity_type='2exp',
                 plasticity_learning_rate=0.002, plasticity_params=(1.0, 0.0),
                 plasticity_taus=(80.0, 40.0), plasticity_constrain_positive=False):
        """ Set up a simulation that can be run for several learning cycles.

        Arguments
        ---------
          target: array (shape (Nmuscles, Nsteps))
              Target output program.
          tmax:
          dt: float
              Length and granularity of target program. `tmax` should be equal to `Nsteps * dt`,
              where `Nsteps` is the number of columns of the `target` (see above).
          n_conductor: int
              Number of conductor neurons.
          n_student_per_output: int
              Number of student neurons per output channel. If `controller_mode` is not 'pushpull',
              the actual number of student neurons is `n_student_per_output * Nmuscles`, where
              `Nmuscles` is the number of rows of `target` (see above). If `controller_mode` is
              `pushpull`, this is further multiplied by 2.
          relaxation: float
              Length of time that the simulation runs past the end of the `target`. This ensures
              that all the contributions from the plasticity rule are considered.
          relaxation_conductor: float
              Length of time that the conductor fires past the end of the `target`. This is to avoid
              excessive decay at the end of the program.
          tracker_generator: callable
              This function is called before every simulation run with the signature
                  `tracker_generator(simulator, i, n)`
              where `simulator` is the object running the simulations (i.e., `self`), `i` is the
              index of the current learning cycle, and `n` is the total number of learning cycles
              that will be simulated. The function should return a dictionary of objects such as
              `StateMonitor`s and `EventMonitor`s that track the system during the simulation
              (returning `None` is treated as an empty dictionary). These objects will be returned
              in the results output structure after the run (see the `run` method).
          snapshot_generator: callable
              This can be a function or a pair of functions. If it is a single function, it is called
              before every simulation run with the signature
                  `snapshot_generator(simulator, i, n)`
              where `simulator` is the object running the simulations (i.e., `self`), `i` is the
              index of the current learning cycle, and `n` is the total number of learning cycles
              that will be simulated. The function should return a dictionary that will be appended
              directly to the results output structure after the run (see the `run` method). This can
              be used to make snapshots of various structures, such as the conductor--student weights,
              during learning.
              
              When this is a pair, both elements should be functions with the same signature as shown
              above (or `None`). The first will be called before the simulation run, and the second
              after.
          conductor_burst_length: float, or None
              Duration of conductor bursts. Set to `None` to have the next burst start where the previous
              one ends.
          conductor_from_table: None, or matrix (n_conductor x n_time_slices)
              If not `None`, instead of using the `RateHVCLayer`, use the given table for the outputs
              of conductor neurons as a function of time. Each column of the table corresponds to one
              time 'slice', which has length given by `conductor_burst_length` (in ms), or a single
              time step `dt` if `conductor_burst_length` is `None`.
          controller_mode: str
              The way in which the student--output weights should be initialized. This can be 'sum' or
              'pushpull' (see `LinearController.__init__` for details).
          controller_tau: float
              Timescale for smoothing of output.
          controller_mismatch_type: str
              Method used to simulate credit assignment mismatch. Only possible option for now is
              'random'.
          controller_mismatch_amount: float
              Fraction of student neurons whose output assignment is mismatched when the motor error
              calculation is performed (this is used by the blackbox tutor rule).
          controller_mismatch_subdivide_by: int
              Number of subdivisions for each controller channel when performing the random mismatch.
              Assignments between different subgroups can get shuffled as if they belonged to
              different outputs.
          controller_error_map_function: None, or function
              If not `None`, use a (nonlinear) function to map the motor error into source error. This
              can be used to handle non-quadratic loss functions (see `LinearController`).
          controller_nonlinearity: None, or function
              If not `None`, use a (nonlinear) function to map the weighted input to the output. This
              can be used to implement linear-nonlinear controllers (see `LinearController`).
          tutor_rule_type: str
              Type of tutor rule to use. Currently this should be set to 'blackbox'.
          tutor_rule_tau: float
              Integration timescale for tutor signal (see `BlackboxTutorRule`).
          tutor_rule_gain: float, or `None`
              If not `None`, sets the gain for the blackbox tutor rule (see `BlackboxTutorRule`).
              Either this or `tutor_rule_gain_per_student` should be non-`None`.
          tutor_rule_gain_per_student: float, or `None`
              Used only when `tutor_rule_gain` is `None`: the gain for the blackbox tutor rule (see
              `BlackboxTutorRule`) is then `tutor_rule_gain_per_student * n_student_per_output`.
          tutor_rule_compress_rates: bool
              Sets the `compress_rates` option for the blackbox tutor rule (see `BlackboxTutorRule`).
          tutor_rule_min_rate: float
          tutor_rule_max_rate: float
              Sets the minimum and maximum rate for the tutor rule (see `BlackboxTutorRule`).
          cs_weights_type: str
          cs_weights_params:
              Sets the way in which the conductor--student weights should be initialized. This can be
                'zero':       set the weights to zero
                'constant':   set all the weights equal to `cs_weights_params`
                'normal':     use Gaussian random variables, parameters (mean, st.dev.) given by
                              `cs_weights_params`
                'lognormal':  use log-normal random variables, parameters (mu, sigma) given by
                              `cs_weights_params`
              Any other value raises an exception.
          cs_weights_scale: float
              Set a scaling factor for all the conductor--student weights. This is applied after the
              weights are calculated according to `cs_weights_type` and `cs_weights_params`.
          ts_weights: float
              The value of the tutor--student synaptic strength.
          plasticity_type: str
              Type of plasticity rule to use:
                '2exp':     use `TwoExponentialsPlasticity`
                'exp_texp': use `SuperExponentialPlasticity`
          plasticity_learning_rate: float
              The learning rate of the plasticity rule (see `TwoExponentialsPlasticity`).
          plasticity_params: (alpha, beta)
              The parameters used by the plasticity rule (see `TwoExponentialsPlasticity`).
          plasticity_taus: (tau1, tau2)
              The timescales used by the plasticity rule (see `TwoExponentialsPlasticity`).
          plasticity_constrain_positive: bool
              Whether to keep conductor--student weights non-negative or not
              (see `TwoExponentialsPlasticity`).
        """
        self.target = np.asarray(target)
        self.tmax = float(tmax)
        self.dt = float(dt)
        
        self.n_muscles = len(self.target)
        
        self.n_conductor = n_conductor
        self.n_student_per_output = n_student_per_output
        
        self.relaxation = relaxation
        self.relaxation_conductor = relaxation_conductor
        
        self.tracker_generator = tracker_generator
        self.snapshot_generator = snapshot_generator
        
        # normalize to a (pre-run, post-run) pair
        if not hasattr(self.snapshot_generator, '__len__'):
            self.snapshot_generator = (self.snapshot_generator, None)
        
        self.conductor_burst_length = conductor_burst_length
        self.conductor_from_table = conductor_from_table
        
        self.controller_mode = controller_mode
        self.controller_tau = controller_tau
        self.controller_mismatch_type = controller_mismatch_type
        self.controller_mismatch_amount = controller_mismatch_amount
        self.controller_mismatch_subdivide_by = controller_mismatch_subdivide_by
        self.controller_error_map_function = controller_error_map_function
        self.controller_nonlinearity = controller_nonlinearity
        
        self.tutor_rule_type = tutor_rule_type
        self.tutor_rule_tau = tutor_rule_tau
        self.tutor_rule_gain = tutor_rule_gain
        self.tutor_rule_gain_per_student = tutor_rule_gain_per_student
        
        self.tutor_rule_compress_rates = tutor_rule_compress_rates
        self.tutor_rule_min_rate = tutor_rule_min_rate
        self.tutor_rule_max_rate = tutor_rule_max_rate
        
        self.cs_weights_type = cs_weights_type
        self.cs_weights_params = cs_weights_params
        self.cs_weights_scale = cs_weights_scale
        
        self.ts_weights = ts_weights
        
        self.plasticity_type = plasticity_type
        self.plasticity_learning_rate = plasticity_learning_rate
        self.plasticity_params = plasticity_params
        self.plasticity_taus = plasticity_taus
        self.plasticity_constrain_positive = plasticity_constrain_positive
        
        self.progress_indicator = ProgressBar(self)
        
        self.setup()
    
    def setup(self):
        """ Create the components of the simulation. """
        # process some of the options
        self.n_student = self.n_student_per_output*self.n_muscles
        
        if self.controller_mode == 'pushpull':
            self.n_student *= 2
        
        if self.tutor_rule_gain is None:
            if self.tutor_rule_gain_per_student is None:
                raise Exception('Either tutor_rule_gain or tutor_rule_gain_per_student '
                                'must be set.')
            self.tutor_rule_actual_gain = self.tutor_rule_gain_per_student*self.n_student_per_output
        else:
            self.tutor_rule_actual_gain = self.tutor_rule_gain
        
        self.total_time = self.tmax + self.relaxation
        self.stimes = np.arange(0, self.total_time, self.dt)
        
        self._current_res = []
        
        # build components
        if self.conductor_from_table is None:
            self.conductor = RateHVCLayer(self.n_conductor, burst_tmax=self.tmax+self.relaxation_conductor,
                                          burst_length=self.conductor_burst_length)
        else:
            # stretch each table column over `conductor_burst_length` worth of time steps
            rep_count = (1 if self.conductor_burst_length is None else
                         int_r(self.conductor_burst_length/self.dt))
            table = np.repeat(self.conductor_from_table, rep_count, axis=1)
            self.conductor = TableLayer(table)
            
        self.student = RateLayer(self.n_student)
        self.motor = LinearController(self.student, self.target,
                                      mode=self.controller_mode, tau=self.controller_tau)
        if self.controller_mismatch_amount > 0:
            if self.controller_mismatch_type != 'random':
                raise Exception('Unknown controller_mismatch_type '+
                                str(self.controller_mismatch_type) + '.')
            self.motor.set_random_permute_inverse(self.controller_mismatch_amount,
                                                  subdivide_by=self.controller_mismatch_subdivide_by)
        self.motor.error_map_fct = self.controller_error_map_function
        self.motor.nonlinearity = self.controller_nonlinearity
        
        if self.tutor_rule_type != 'blackbox':
            raise Exception('Unknown tutor_rule_type ' + str(self.tutor_rule_type) + '.')
        
        self.tutor_rule = BlackboxTutorRule(self.motor, tau=self.tutor_rule_tau,
                                            gain=self.tutor_rule_actual_gain,
                                            compress_rates=self.tutor_rule_compress_rates,
                                            min_rate=self.tutor_rule_min_rate,
                                            max_rate=self.tutor_rule_max_rate)
        # the blackbox tutor rule will wind down its activity during relaxation time
        self.tutor_rule.relaxation = self.relaxation

        # add synapses to student (source 0: conductor, source 1: tutor)
        self.student.add_source(self.conductor)
        self.student.add_source(self.tutor_rule)
        
        # generate the conductor--student weights
        self.init_cs_weights()
        
        # set tutor--student weights; the bias subtracts the tutor's mid-range rate
        self.student.Ws[1] = self.ts_weights*np.ones(self.n_student)
        self.student.bias = -self.ts_weights*(self.tutor_rule_min_rate + self.tutor_rule_max_rate)/2.0
        
        # initialize the plasticity rule
        if self.plasticity_type == '2exp':
            self.plasticity = TwoExponentialsPlasticity(
                (self.conductor, self.student, self.student.Ws[0]), self.tutor_rule,
                rate=self.plasticity_learning_rate,
                alpha=self.plasticity_params[0], beta=self.plasticity_params[1],
                tau1=self.plasticity_taus[0], tau2=self.plasticity_taus[1],
                constrain_positive=self.plasticity_constrain_positive)
        elif self.plasticity_type == 'exp_texp':
            self.plasticity = SuperExponentialPlasticity(
                (self.conductor, self.student, self.student.Ws[0]), self.tutor_rule,
                rate=self.plasticity_learning_rate,
                alpha=self.plasticity_params[0], beta=self.plasticity_params[1],
                tau1=self.plasticity_taus[0], tau2=self.plasticity_taus[1],
                constrain_positive=self.plasticity_constrain_positive)
        else:
            raise Exception('Unknown plasticity_type ' + str(self.plasticity_type) + '.')

    def init_cs_weights(self):
        """ Initialize conductor--student weights.

        The weights (shape `(n_student, n_conductor)`) are generated according to
        `cs_weights_type` and `cs_weights_params` (see `__init__`) and then multiplied by
        `cs_weights_scale`. An unknown `cs_weights_type` raises an exception instead of
        silently leaving whatever weights were there before.
        """
        if self.cs_weights_type == 'zero':
            self.student.Ws[0] = np.zeros((self.n_student, self.n_conductor))
        elif self.cs_weights_type == 'constant':
            self.student.Ws[0] = self.cs_weights_params*np.ones((self.n_student, self.n_conductor))
        elif self.cs_weights_type == 'normal':
            self.student.Ws[0] = (self.cs_weights_params[0] +
                                  self.cs_weights_params[1]*np.random.randn(self.n_student, self.n_conductor))
        elif self.cs_weights_type == 'lognormal':
            self.student.Ws[0] = np.random.lognormal(*self.cs_weights_params,
                                                     size=(self.n_student, self.n_conductor))
        else:
            raise Exception('Unknown cs_weights_type ' + str(self.cs_weights_type) + '.')
        
        self.student.Ws[0] *= self.cs_weights_scale

    def run(self, n_runs):
        """ Run the simulation for `n_runs` learning cycles.
        
        This function intercepts `KeyboardInterrupt` exceptions and returns the results up to
        the time of the keyboard interrupt.

        Returns a list with one dict per completed learning cycle, holding 'average_error'
        (the mean of the motor error over the cycle) plus the entries returned by the
        snapshot generators and the trackers.
        """
        res = []
        
        # exposed so that the progress indicator can read the partial results
        self._current_res = res
                
        try:
            for i in xrange(n_runs):
                if self.progress_indicator is not None:
                    self.progress_indicator(i, n_runs)
                
                # make the pre-run snapshots
                if self.snapshot_generator[0] is not None:
                    snaps_pre = self.snapshot_generator[0](self, i, n_runs)
                else:
                    snaps_pre = {}
                
                # get the trackers
                if self.tracker_generator is not None:
                    trackers = self.tracker_generator(self, i, n_runs)
                    if trackers is None:
                        trackers = {}
                else:
                    trackers = {}
                                
                # no matter what, we will need an error tracker to calculate average error
                M_merr = MotorErrorTracker(self.motor)
                
                # create and run the simulation
                sim = simulation.Simulation(self.conductor, self.student, self.tutor_rule,
                                            self.motor, self.plasticity,
                                            M_merr, *trackers.values(), dt=self.dt)
                sim.run(self.total_time)
                
                # make the post-run snapshots
                if self.snapshot_generator[1] is not None:
                    snaps_post = self.snapshot_generator[1](self, i, n_runs)
                else:
                    snaps_post = {}
                
                crt_res = {'average_error': np.mean(M_merr.overall_error)}
                crt_res.update(snaps_pre)
                crt_res.update(snaps_post)
                crt_res.update(trackers)
                
                res.append(crt_res)
                
            if self.progress_indicator is not None:
                self.progress_indicator(n_runs, n_runs)
        except KeyboardInterrupt:
            # deliberate: keep whatever was computed before the interrupt
            pass
        
        return res

In [ ]:
# sometimes we need to run a set of simulations in which some parameters change

def simulate_many(constant_params, variable_params):
    """ Run a set of simulations.
    
    Note that each run must have 'target', 'tmax', 'dt', and 'n_reps' entries in the dictionary,
    either coming from `constant_params` or from `variable_params`. The special entries
    'graph_step' and 'show_graph' can be used to control the progress indicator for the
    `RateLearningSimulation`s. All remaining entries are passed as keyword arguments to
    `RateLearningSimulation`.
    
    Arguments
    ---------
      constant_params: dict
          Dictionary holding those parameters that are constant for all simulations.
      variable_params: array of dict
          Array of dictionaries holding the parameters that change between different simulations.
          These override entries of `constant_params`. This can have any number of dimensions,
          and the output will match its shape.
    
    Returns
    -------
      res_array: array of dict
          Array of results for each of the simulations. Each dict has three entries,
            'params':       a dictionary containing the parameters for the simulation (except
                            'target', which can be too big, 'graph_step', 'show_graph', and the
                            tracker/snapshot generator functions)
            'trace':        the data resulting from that simulation, as returned by the 
                            `RateLearningSimulation` object
            'error_trace':  the error trace (the 'average_error' values for all the entries in
                            `trace`)
          This has the same shape as the `variable_params` argument.

    Note that `RateLearningSimulation.run` swallows `KeyboardInterrupt`, so an interrupt only
    cuts short the current simulation; the loop then continues with the next one.
    """    
    # initialize the output
    variable_params = np.asarray(variable_params)
    res_array = np.empty(np.shape(variable_params), dtype='object')
    
    n_sims = np.size(variable_params)
    
    # display a progress bar
    t0 = time.time()
    bar = FloatProgress(min=0, max=100)
    display(bar)
    
    # process the data; `.flat` indexes both arrays in the same (C) order
    for i in xrange(n_sims):
        t_diff = time.time() - t0
        
        bar.value = min(round(i*100.0/n_sims), 100)
        bar.description = 'simulation #{} ; time elapsed: {:.1f}s'.format(i, t_diff)
        
        # constant parameters, overridden by those specific to this simulation
        current_params = dict(constant_params)
        current_params.update(variable_params.flat[i])
        
        current_target = current_params.pop('target')
        current_tmax = current_params.pop('tmax')
        current_dt = current_params.pop('dt')
        current_n_reps = current_params.pop('n_reps')
        
        current_graph_step = current_params.pop('graph_step', 50)
        current_show_graph = current_params.pop('show_graph', True)
        
        clear_output(wait=True)
        plt.clf()
        
        current_sim = RateLearningSimulation(
            current_target, current_tmax, current_dt, **current_params)
        current_sim.progress_indicator.graph_step = current_graph_step
        current_sim.progress_indicator.show_graph = current_show_graph
        current_sim.progress_indicator.print_last = False
        
        current_res = current_sim.run(current_n_reps)
        
        # don't keep the functions in the parameters
        current_params.pop('tracker_generator', None)
        current_params.pop('snapshot_generator', None)
        
        # re-add n_reps, tmax, and dt; not target, because that can be too big
        current_params['tmax'] = current_tmax
        current_params['dt'] = current_dt
        current_params['n_reps'] = current_n_reps
        
        res_array.flat[i] = {
            'params': current_params,
            'trace': current_res,
            'error_trace': np.asarray([_['average_error'] for _ in current_res])}
        
    bar.value = 100.0
    t_diff = time.time() - t0

    bar.description = 'simulation done ; time elapsed: {:.1f}s'.format(t_diff)
    
    return res_array

In [ ]:
# tracker functions that will be used throughout the code

def tracker_generator(simulator, i, n):
    """ Generate some trackers.

    Returns a dict that maps 'motor', 'tutor', 'conductor', and 'student' to
    `StateMonitor`s recording the 'out' attribute of the corresponding component
    of `simulator`. The cycle index `i` and total `n` are not used.
    """
    watched = [('motor', simulator.motor),
               ('tutor', simulator.tutor_rule),
               ('conductor', simulator.conductor),
               ('student', simulator.student)]
    return {name: simulation.StateMonitor(component, 'out')
            for name, component in watched}

def snapshot_generator_pre(simulator, i, n):
    """ Generate some pre-run snapshots.

    Returns `{'weights': <copy of the conductor--student weights>}`. The cycle index
    `i` and total `n` are accepted to match the snapshot-generator signature but are
    not used.
    """
    cs_weights = simulator.student.Ws[0]
    # copy, so later in-place changes to the weight matrix don't alter the snapshot
    return {'weights': np.copy(cs_weights)}

Create the data for the mismatch matrix

Some generic code

We first define a function that generates the tutor--student mismatch matrix under a variety of conditions, such as 'sum' or 'pushpull' motor controllers, conductor--student weights constrained to be positive, etc. Then we use this function for all the cases of interest.


In [ ]:
def generate_mismatch_data(tau_levels, n_reps, target, tmax, dt, **params):
    """ Generate a matrix of results showing the effect of tutor-student mismatch.
    
    The extra arguments give a dictionary of parameters that override the defaults. These are assumed
    to be constant over all the simulations. Here the plasticity parameters are constrained
    such that alpha - beta = 1.

    Each timescale `tau` in `tau_levels` is used in two ways: as a tutor timescale
    (`tutor_rule_tau`) and, through alpha = (tau - tau2)/(tau1 - tau2), beta = alpha - 1, as a
    student plasticity setting (with `(tau1, tau2)` the `plasticity_taus`). The result is a
    `len(tau_levels) x len(tau_levels)` array from `simulate_many` whose rows correspond to
    the student (plasticity) level and columns to the tutor timescale.

    Raises `ValueError` if the two plasticity timescales are equal, because alpha is then
    undefined.
    """
    # this is all we're tracking
    def tracker_generator(simulator, i, n):
        """ Generate some trackers. """
        res = {}
        res['motor'] = simulation.StateMonitor(simulator.motor, 'out')
    
        return res
    
    def snapshot_generator(simulator, i, n):
        """ Snapshot of the conductor--student weights before each learning cycle. """
        res = {}
        res['weights'] = np.copy(simulator.student.Ws[0])
        
        return res
    
    # update the default parameters using the arguments
    default_params = dict(
        show_graph=False,
        n_conductor=100, n_student_per_output=1,
        relaxation=400.0, relaxation_conductor=25.0,
        conductor_burst_length=None, controller_mode='sum', controller_tau=25.0,
        tutor_rule_gain_per_student=0.5, tutor_rule_compress_rates=False,
        tutor_rule_min_rate=0.0, tutor_rule_max_rate=160.0,
        cs_weights_type='lognormal', cs_weights_params=(-3.57, 0.54), cs_weights_scale=200.0,
        ts_weights=0.01,
        plasticity_learning_rate=0.001,
        plasticity_taus=(80.0, 40.0),
        plasticity_constrain_positive=False,
        tracker_generator=tracker_generator,
        snapshot_generator=snapshot_generator
    )
    actual_params = dict(default_params)
    actual_params.update(params)

    # calculate the plasticity parameters for each tau
    tau1, tau2 = actual_params['plasticity_taus']
    if tau1 == tau2:
        raise ValueError('plasticity_taus must be distinct to solve for (alpha, beta); got ' +
                         str((tau1, tau2)) + '.')
    
    # fix alpha - beta = 1; with a kernel alpha*exp(-t/tau1) - beta*exp(-t/tau2) this would make
    # the kernel's integral equal to tau (presumably the intent -- confirm against the rule)
    plasticity_levels = []
    for tau in tau_levels:
        alpha = float(tau - tau2)/(tau1 - tau2)
        plasticity_levels.append((alpha, alpha - 1))
        
    actual_params['target'] = target
    actual_params['tmax'] = tmax
    actual_params['dt'] = dt
    actual_params['n_reps'] = n_reps
    
    # rows: student plasticity level; columns: tutor timescale
    return simulate_many(
        actual_params,
        [[dict(
                tutor_rule_tau=current_tau_tutor,
                plasticity_params=current_plastic
            ) for current_tau_tutor in tau_levels] for current_plastic in plasticity_levels]
    )

Generate mismatch matrix for 'sum' controller


In [ ]:
# reproducible randomness
np.random.seed(23782)

# 8 timescale levels 10*2**k = 10, 20, ..., 1280 (used for both tutor and student, see
# `generate_mismatch_data`), 250 learning cycles per simulation, default 'sum' controller
res_mismatch = generate_mismatch_data(10*2**np.arange(8), 250, target, tmax, dt)

In [ ]:
file_name = 'save/rate_based_results_sum_log_8.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Generate bigger mismatch matrix for 'sum' controller


In [ ]:
# reproducible randomness
np.random.seed(23782)

# 12 timescale levels 10*2**k = 10, 20, ..., 20480, 1000 learning cycles per simulation,
# and longer relaxation periods (1200 and 50 instead of the defaults 400 and 25) --
# presumably so that the slowest timescales have time to play out
res_mismatch = generate_mismatch_data(10*2**np.arange(12), 1000, target, tmax, dt,
                                      relaxation=1200.0, relaxation_conductor=50.0)

In [ ]:
file_name = 'save/rate_based_results_sum_log_12.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Generate mismatch matrix for 'pushpull' controller


In [ ]:
# reproducible randomness
np.random.seed(123134)

# 8 timescale levels 10, 20, ..., 1280 and 250 learning cycles, with a 'pushpull' controller
# (which doubles the number of student neurons, see `RateLearningSimulation`)
res_mismatch = generate_mismatch_data(10*2**np.arange(8), 250, target, tmax, dt,
                                      controller_mode='pushpull')

In [ ]:
file_name = 'save/rate_based_results_pushpull_log_8.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Generate mismatch matrix when weights are constrained positive

With 'sum' controller.


In [ ]:
# reproducible randomness
np.random.seed(34234)

# 8 timescale levels 10, 20, ..., 1280 and 250 learning cycles, 'sum' controller, with the
# plasticity rule keeping conductor--student weights non-negative
res_mismatch = generate_mismatch_data(10*2**np.arange(8), 250, target, tmax, dt,
                                      plasticity_constrain_positive=True)

In [ ]:
file_name = 'save/rate_based_results_posweights_log_8.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Generate mismatch matrix when $\tau_{1,2}$ are small


In [ ]:
# reproducible randomness
np.random.seed(4324)

# 8 timescale levels 10, 20, ..., 1280 and 250 learning cycles, with plasticity kernel
# timescales (tau1, tau2) = (20, 10) instead of the default (80, 40)
res_mismatch = generate_mismatch_data(10*2**np.arange(8), 250, target, tmax, dt,
                                      plasticity_taus=(20.0, 10.0))

In [ ]:
file_name = 'save/rate_based_results_small_tau_log_8.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Generate mismatch matrix when $\tau_{1,2}$ are large


In [ ]:
# reproducible randomness
np.random.seed(76476)

# 8 timescale levels 10, 20, ..., 1280 and 250 learning cycles, with plasticity kernel
# timescales (tau1, tau2) = (160, 80) instead of the default (80, 40)
res_mismatch = generate_mismatch_data(10*2**np.arange(8), 250, target, tmax, dt,
                                      plasticity_taus=(160.0, 80.0))

In [ ]:
file_name = 'save/rate_based_results_large_tau_log_8.pkl'
# never overwrite previously saved results
if os.path.exists(file_name):
    raise Exception('File exists!')
# create the output folder if needed, so a long simulation isn't lost to a missing directory
save_dir = os.path.dirname(file_name)
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format readable from both Python 2 and Python 3
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt}, out, 2)

Create credit mis-assignment data

Some generic code for credit mis-assignment

We first define a function that generates the credit mis-assignment data under a variety of conditions, such as 'sum' or 'pushpull' motor controllers, conductor--student weights constrained to be positive, etc. Then we use this function for all the cases of interest.


In [ ]:
def generate_credit_mismatch_data(rho_levels, n_reps, target, tmax, dt, **params):
    """ Generate a vector of results showing the effect of credit assignment mismatch.

    Parameters
    ----------
    rho_levels : iterable of float
        Fractions of student neurons whose output assignment will be mismatched.
        Each level is handed to the simulator as `controller_mismatch_amount`.
    n_reps : int
        Stored as the `n_reps` parameter of every run (presumably the number of
        learning repetitions).
    target, tmax, dt :
        Target motor program, its duration and the simulation time step. These
        always take precedence over any same-named entry in `params`.
    **params :
        Parameters overriding the defaults below. They are held constant over all
        the simulations.

    Returns
    -------
    Whatever `simulate_many` returns for the list of per-level overrides --
    presumably one result per entry of `rho_levels`, in the same order.
    """
    def tracker_generator(simulator, i, n):
        """ Track only the motor output, on every repetition. """
        return {'motor': simulation.StateMonitor(simulator.motor, 'out')}

    default_params = dict(
        show_graph=False,
        n_conductor=100, n_student_per_output=40,
        relaxation=400.0, relaxation_conductor=25.0,
        conductor_burst_length=None, controller_mode='sum', controller_tau=25.0,
        tutor_rule_gain_per_student=0.5, tutor_rule_compress_rates=False,
        tutor_rule_tau=40.0,
        tutor_rule_min_rate=0.0, tutor_rule_max_rate=160.0,
        cs_weights_type='lognormal', cs_weights_params=(-3.57, 0.54), cs_weights_scale=200.0,
        ts_weights=0.01,
        plasticity_learning_rate=0.001,
        plasticity_taus=(80.0, 40.0),
        # float for both entries, matching the direct simulations elsewhere in the
        # notebook; avoids any Python 2 integer-division surprise downstream
        plasticity_params=(0.0, -1.0),
        plasticity_constrain_positive=False,
        tracker_generator=tracker_generator,
        snapshot_generator=None
    )

    actual_params = dict(default_params)
    actual_params.update(params)
    # explicit arguments always win over the defaults and over `params`
    actual_params.update(target=target, tmax=tmax, dt=dt, n_reps=n_reps)

    # one set of overrides per mismatch level
    per_level = [dict(controller_mismatch_amount=rho) for rho in rho_levels]
    return simulate_many(actual_params, per_level)

Generate credit mismatch data for 'sum' controller


In [ ]:
np.random.seed(23872)   # reproducible randomness

# mismatch fraction from 0 to 1 in steps of 0.025, 'sum' controller
res_mismatch = generate_credit_mismatch_data(
    np.arange(0.0, 1.025, 0.025), 250, target, tmax, dt,
    controller_mode='sum',
    snapshot_generator=None)

In [ ]:
file_name = 'save/rate_based_credit_results_sum.pkl'
# refuse to clobber results that were saved earlier
if os.path.exists(file_name):
    raise Exception('File exists!')
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt},
                out, 2)

Generate credit mismatch data for 'pushpull' controller, no subdivision


In [ ]:
np.random.seed(8372847)   # reproducible randomness

# mismatch fraction from 0 to 1 in steps of 0.025, 'pushpull' controller
res_mismatch = generate_credit_mismatch_data(
    np.arange(0.0, 1.025, 0.025), 250, target, tmax, dt,
    controller_mode='pushpull')

In [ ]:
file_name = 'save/rate_based_credit_results_pushpull.pkl'
# refuse to clobber results that were saved earlier
if os.path.exists(file_name):
    raise Exception('File exists!')
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt},
                out, 2)

Generate credit mismatch data for 'pushpull' controller, subdivide by 2

Here we allow the excitatory and inhibitory contributions within the same output channel to be mismatched.


In [ ]:
np.random.seed(8372847)   # reproducible randomness (same seed as the unsubdivided run)

# 'pushpull' controller, mismatch allowed between the two halves of each channel
res_mismatch = generate_credit_mismatch_data(
    np.arange(0.0, 1.025, 0.025), 250, target, tmax, dt,
    controller_mode='pushpull',
    controller_mismatch_subdivide_by=2)

In [ ]:
file_name = 'save/rate_based_credit_results_pushpull_subdiv.pkl'
# refuse to clobber results that were saved earlier
if os.path.exists(file_name):
    raise Exception('File exists!')
with open(file_name, 'wb') as out:
    # protocol 2: binary pickle format
    pickle.dump({'res_mismatch': res_mismatch, 'target': target, 'tmax': tmax, 'dt': dt},
                out, 2)

Make figures


In [ ]:
def compare_traces(res1, res2, target,
                   colors=((0.200, 0.357, 0.400), (0.831, 0.333, 0.000)),
                   labels=None, ymax=None, inset_ymax=None,
                   inset_label_size=8, inset_legend_pos=(1.1, 1.1),
                   inset_pos=(0.4, 0.4, 0.4, 0.4), channel=0):
    """ Make a plot comparing two convergence traces.

    The current axes show the evolution of the error over repetitions for both runs;
    an inset compares the final motor output of each run against the target.

    Parameters
    ----------
    res1, res2 : sequence of dict
        Per-repetition results (e.g. from `simulator.run`). Every entry needs an
        'average_error' value; the last entry needs a 'motor' monitor with `t` and
        `out` attributes.
    target : array
        Target motor program; row `channel` is drawn in the inset, and its number of
        columns sets how many time steps are shown.
    colors : pair of colors
        Colors used for `res1` and `res2`.
    labels : pair of str, optional
        Labels for the two runs; the inset legend is drawn only if at least one
        label is not None.
    ymax, inset_ymax : float, optional
        Upper y-limits of the main plot and of the inset. Lower limits are always 0;
        when None, the automatically chosen upper limit is kept.
    inset_label_size : int
        Font size of the inset legend.
    inset_legend_pos : tuple
        `bbox_to_anchor` for the inset legend.
    inset_pos : sequence of 4 floats
        Inset rectangle [left, bottom, width, height] in figure coordinates.
    channel : int
        Output channel shown in the inset.
    """
    if labels is None:
        labels = [None, None]

    runs = (res1, res2)

    # main plot: error vs. repetition
    ax = plt.gca()
    for run, color, label in zip(runs, colors, labels):
        ax.plot([step['average_error'] for step in run], c=color, label=label)

    ax.set_xlabel('repetition')
    ax.set_ylabel('error')
    # anchor the error axis at zero; keep the autoscaled top unless one was given
    ax.set_ylim(0, ymax if ymax is not None else ax.get_ylim()[1])

    # inset: final motor output of each run against the target
    inax = plt.axes(list(inset_pos))

    final_motors = [run[-1]['motor'] for run in runs]
    nsteps = np.shape(target)[1]
    times = final_motors[0].t[:nsteps]
    inax.plot(times, target[channel], ':k', lw=4, label='target')

    inax.spines['right'].set_color('none')
    inax.spines['top'].set_color('none')

    for motor, color, label in zip(final_motors, colors, labels):
        inax.plot(times, motor.out[channel, :nsteps], c=color, label=label)

    inax.set_xlabel('time')
    inax.set_ylabel('output')

    inax.set_xticks([])
    inax.set_yticks([])

    if any(label is not None for label in labels):
        inax.legend(bbox_to_anchor=inset_legend_pos, fontsize=inset_label_size)

    inax.set_ylim(0, inset_ymax if inset_ymax is not None else inax.get_ylim()[1])

Tutor-student mismatch heatmap and convergence map

Mismatch 'sum' controller


In [ ]:
# load the saved 'sum'-controller mismatch results; only the 'res_mismatch' entry of
# the pickled dict is kept
file_name = 'save/rate_based_results_sum_log_8.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the 'sum'-controller mismatch results
make_heatmap_plot(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_heatmap_sum_log_8',
              png=False)

In [ ]:
# convergence map of the 'sum'-controller mismatch results
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_sum_log_8',
              png=False)

Mismatch 'sum' controller, bigger matrix


In [ ]:
# load the larger 'sum'-controller mismatch results; only the 'res_mismatch' entry of
# the pickled dict is kept
file_name = 'save/rate_based_results_sum_log_12.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the bigger matrix, with an explicit color range
make_heatmap_plot(res_mismatch,
                  vmin=0.5, vmax=10)
safe_save_fig('figs/ratebased_mismatch_heatmap_sum_log_12',
              png=False)

In [ ]:
# convergence map of the bigger matrix
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_sum_log_12',
              png=False)

In [ ]:
# final error of every run, one row per entry of res_mismatch
error_matrix = np.asarray([[_['error_trace'][-1] for _ in crt_res]
                           for crt_res in res_mismatch])
# any non-finite final error (NaN or +/-inf) is mapped to +inf
error_matrix[~np.isfinite(error_matrix)] = np.inf

# tutor-rule timescales, read off the runs in the first row
tau_levels = np.asarray([_['params']['tutor_rule_tau']
                         for _ in res_mismatch[0]])

In [ ]:
# final error along the diagonal of error_matrix vs. the tutor-rule timescale;
# label the axes so the figure can be read on its own
plt.semilogx(tau_levels, np.diag(error_matrix), '.-k')
plt.xlabel('tutor timescale')
plt.ylabel('final error');

Mismatch 'pushpull' controller


In [ ]:
# load the saved 'pushpull'-controller mismatch results; only the 'res_mismatch'
# entry of the pickled dict is kept
file_name = 'save/rate_based_results_pushpull_log_8.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the 'pushpull'-controller mismatch results
make_heatmap_plot(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_heatmap_pushpull_log_8',
              png=False)

In [ ]:
# convergence map of the 'pushpull'-controller mismatch results
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_pushpull_log_8',
              png=False)

Mismatch positive weights


In [ ]:
# load the positive-weight mismatch results; the pickled dict also holds target/tmax/dt,
# but only 'res_mismatch' is kept here
file_name = 'save/rate_based_results_posweights_log_8.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the positive-weight mismatch results
make_heatmap_plot(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_heatmap_posweights_log_8',
              png=False)

In [ ]:
# convergence map of the positive-weight mismatch results
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_posweights_log_8',
              png=False)

Mismatch small $\tau_{1,2}$


In [ ]:
# load the small-tau mismatch results; the pickled dict also holds target/tmax/dt,
# but only 'res_mismatch' is kept here
file_name = 'save/rate_based_results_small_tau_log_8.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the small-tau mismatch results
make_heatmap_plot(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_heatmap_small_tau_log_8',
              png=False)

In [ ]:
# convergence map of the small-tau mismatch results
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_small_tau_log_8',
              png=False)

Mismatch large $\tau_{1,2}$


In [ ]:
# load the large-tau mismatch results; the pickled dict also holds target/tmax/dt,
# but only 'res_mismatch' is kept here
file_name = 'save/rate_based_results_large_tau_log_8.pkl'
with open(file_name, 'rb') as inp:
    res_mismatch = pickle.load(inp)['res_mismatch']

In [ ]:
# heatmap of the large-tau mismatch results
make_heatmap_plot(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_heatmap_large_tau_log_8',
              png=False)

In [ ]:
# convergence map of the large-tau mismatch results
make_convergence_map(res_mismatch)
safe_save_fig('figs/ratebased_mismatch_convmap_large_tau_log_8',
              png=False)

Convergence plots


In [ ]:
def tracker_generator(simulator, i, n):
    """ Attach a monitor on the motor output for every repetition.

    `i` and `n` are accepted for the tracker-generator signature but not used.
    """
    return {'motor': simulation.StateMonitor(simulator.motor, 'out')}

def snapshot_generator_pre(simulator, i, n):
    """ Pre-run snapshot: a copy of `simulator.student.Ws[0]` under the key 'weights'.

    `i` and `n` are accepted for the snapshot-generator signature but not used.
    """
    weights = np.copy(simulator.student.Ws[0])
    return {'weights': weights}

Convergence 'sum' controller

Short timescale (convergence 'sum')


In [ ]:
np.random.seed(32292)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # controller and tutor rule (short tutor timescale)
    controller_mode='sum',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

In [ ]:
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_traces=[0, 4, 8], extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True,
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

# y range of the first returned axes (presumably the main error plot)
axs[0].set_ylim(0, 15);

safe_save_fig('figs/ratebased_convergence_plot_sum_small_tau', png=False)

In [ ]:
# movie over all 250 repetitions, 5 units long
make_convergence_movie(
    'figs/ratebased_convergence_movie_sum_small_tau.mov',
    res, target,
    idxs=range(250), length=5.0, ymax=80.0)

Long timescale (convergence 'sum')


In [ ]:
np.random.seed(32290)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # controller and tutor rule (long tutor timescale)
    controller_mode='sum',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

In [ ]:
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_traces=[0, 4, 8], extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True,
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

# y range of the first returned axes (presumably the main error plot)
axs[0].set_ylim(0, 15);

# NOTE(review): unlike the short-timescale version, this call does not pass
# png=False -- confirm whether that is intended
safe_save_fig('figs/ratebased_convergence_plot_sum_large_tau')

In [ ]:
# movie over all 250 repetitions, 5 units long
make_convergence_movie(
    'figs/ratebased_convergence_movie_sum_large_tau.mov',
    res, target,
    idxs=range(250), length=5.0, ymax=80.0)

Convergence 'pushpull' controller

Short timescale (convergence 'pushpull')


In [ ]:
np.random.seed(22292)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # controller and tutor rule (short tutor timescale)
    controller_mode='pushpull',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

In [ ]:
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_traces=[0, 4, 8], extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True,
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

# y range of the first returned axes (presumably the main error plot)
axs[0].set_ylim(0, 15);

safe_save_fig('figs/ratebased_convergence_plot_pushpull_small_tau')

In [ ]:
# movie over all 250 repetitions, 5 units long
make_convergence_movie(
    'figs/ratebased_convergence_movie_pushpull_small_tau.mov',
    res, target,
    idxs=range(250), length=5.0, ymax=80.0)

Long timescale (convergence 'pushpull')


In [ ]:
np.random.seed(32290)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # controller and tutor rule (long tutor timescale)
    controller_mode='pushpull',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

In [ ]:
fig = plt.figure(figsize=(3, 2))
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_traces=[0, 4, 8], extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True, inset_pos=[0.45, 0.45, 0.4, 0.4],
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

# y range of the first returned axes (presumably the main error plot)
axs[0].set_ylim(0, 15);

safe_save_fig('figs/ratebased_convergence_plot_pushpull_large_tau')

In [ ]:
# movie over all 250 repetitions, 5 units long
make_convergence_movie(
    'figs/ratebased_convergence_movie_pushpull_large_tau.mov',
    res, target,
    idxs=range(250), length=5.0, ymax=80.0)

Convergence comparison 'sum' vs. 'pushpull'

'Sum' vs 'pushpull' short timescale


In [ ]:
np.random.seed(212994)   # keep things arbitrary but reproducible

# 'sum' controller, short tutor timescale
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_sum = simulator.run(250)

# identical settings except for the 'pushpull' controller
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='pushpull',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_pushpull = simulator.run(250)

In [ ]:
plt.figure(figsize=(6, 4))
compare_traces(res_sum, res_pushpull, target,
               labels=['sum', 'push-pull'],
               ymax=20, inset_ymax=80)

safe_save_fig('figs/ratebased_sum_vs_pushpull_short_tau',
              png=False)

'Sum' vs 'pushpull' long timescale


In [ ]:
np.random.seed(202094)   # keep things arbitrary but reproducible

# 'sum' controller, long tutor timescale
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_sum = simulator.run(250)

# identical settings except for the 'pushpull' controller
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='pushpull',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_pushpull = simulator.run(250)

In [ ]:
plt.figure(figsize=(6, 4))
compare_traces(res_sum, res_pushpull, target,
               labels=['sum', 'push-pull'],
               ymax=20, inset_ymax=80)

safe_save_fig('figs/ratebased_sum_vs_pushpull_long_tau',
              png=False)

Effect of firing rate constraint

Firing rate constraint short timescale


In [ ]:
np.random.seed(8735)   # keep things arbitrary but reproducible

# without the firing-rate constraint
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

# same settings, but with tutor_rule_compress_rates switched on
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=40.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=True,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(0.0, -1.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_fc = simulator.run(250)

In [ ]:
plt.figure(figsize=(3, 2))
compare_traces(res, res_fc, target,
               labels=['no constraint', 'with constraint'],
               ymax=20, inset_ymax=80,
               inset_label_size=6, inset_legend_pos=(1.2, 1.1))

# NOTE(review): saving is disabled for this 3x2 version; the next cell saves a 4x3
# version of the same comparison
#safe_save_fig('figs/ratebased_rate_constraint_short_tau', png=False)

In [ ]:
plt.figure(figsize=(4, 3))
compare_traces(res, res_fc, target,
               labels=['no constraint', 'with constraint'],
               ymax=20, inset_ymax=80)

safe_save_fig('figs/ratebased_rate_constraint_short_tau_square',
              png=False)

In [ ]:
# side-by-side movie of both runs over all 250 repetitions
make_convergence_movie(
    'figs/ratebased_rate_constraint_movie_short_tau.mov',
    (res, res_fc), target,
    idxs=range(250), length=5.0, ymax=80.0,
    labels=['no constraint', 'with constraint'])

Firing rate constraint long timescale


In [ ]:
np.random.seed(202094)   # keep things arbitrary but reproducible

# without the firing-rate constraint
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res = simulator.run(250)

# same settings, but with tutor_rule_compress_rates switched on
simulator = RateLearningSimulation(
    target, tmax, dt,
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    controller_mode='sum',
    tutor_rule_tau=320.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=True,
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(7.0, 6.0),
    plasticity_constrain_positive=False,
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_fc = simulator.run(250)

In [ ]:
plt.figure(figsize=(3, 2))
compare_traces(res, res_fc, target,
               labels=['no constraint', 'with constraint'],
               ymax=20, inset_ymax=80,
               inset_label_size=6, inset_legend_pos=(1.2, 1.1))

safe_save_fig('figs/ratebased_rate_constraint_long_tau',
              png=False)

Stepwise learning with 'pushpull' controller


In [ ]:
np.random.seed(2294)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # 'pushpull' controller, very long tutor timescale, rate constraint on
    controller_mode='pushpull',
    tutor_rule_tau=1000.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=True,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(24.0, 23.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_fc = simulator.run(250)

In [ ]:
fig, axs = plt.subplots(1, 4, figsize=(6, 1.75))
# four stages of learning, presumably indices into res_fc
show_program_development(res_fc, axs, target=target, ymax=80,
                         stages=[0, 16, 33, 249])
fig.tight_layout()

safe_save_fig('figs/ratebased_rate_constraint_pushpull_stepwise',
              png=False)

Stepwise learning with 'sum' controller


In [ ]:
np.random.seed(202094)   # keep things arbitrary but reproducible

simulator = RateLearningSimulation(
    target, tmax, dt,
    # sizes, relaxation and conductor firing
    n_conductor=100, n_student_per_output=1,
    relaxation=400.0, relaxation_conductor=25.0,
    conductor_burst_length=None,
    # 'sum' controller, very long tutor timescale, rate constraint on
    controller_mode='sum',
    tutor_rule_tau=1000.0,
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=True,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_learning_rate=0.001,
    plasticity_taus=(80.0, 40.0),
    plasticity_params=(24.0, 23.0),
    plasticity_constrain_positive=False,
    # recording
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre)
res_fc = simulator.run(250)

In [ ]:
fig, axs = plt.subplots(1, 4, figsize=(6, 1.75))
# four stages of learning, presumably indices into res_fc
show_program_development(res_fc, axs, target=target, ymax=80,
                         stages=[0, 16, 33, 249])
fig.tight_layout()

safe_save_fig('figs/ratebased_rate_constraint_stepwise',
              png=False)

In [ ]:
fig, axs = plt.subplots(3, 1, figsize=(2, 4))
# vertical layout with only the first three stages
show_program_development(res_fc, axs, target=target, ymax=80,
                         stages=[0, 16, 33], bbox_pos=(1.15, 1.15))
fig.tight_layout()

safe_save_fig('figs/ratebased_rate_constraint_stepwise_vertical',
              png=False)

In [ ]:
# movie over the first 200 repetitions, 10 units long
make_convergence_movie(
    'figs/ratebased_rate_constraint_stepwise_movie.mov',
    res_fc, target,
    idxs=range(200), length=10.0, ymax=80.0)

Credit misassignment figures

Credit misassignment figures definitions


In [ ]:
def make_credit_plot(res_mismatch, color=(0.200, 0.357, 0.400), figsize=(2.5, 2.25)):
    """ Plot the final error against the fraction of mismatched credit assignment.

    Parameters
    ----------
    res_mismatch : sequence of dict
        One result per mismatch level (e.g. from `generate_credit_mismatch_data`).
        Each entry needs `['params']['controller_mismatch_amount']` and an
        `'error_trace'` whose last element is plotted.
    color : color
        Line/marker color.
    figsize : tuple
        Size of the new figure.

    Returns
    -------
    ax : the Axes holding the plot. It is also the current axes, so later
        `plt.xlim` / `plt.ylim` calls apply to it.
    """
    rho_levels = [run['params']['controller_mismatch_amount'] for run in res_mismatch]
    final_error = [run['error_trace'][-1] for run in res_mismatch]

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(rho_levels, final_error, '.-', color=color)
    ax.grid(True)
    ax.set_xlabel('fraction mismatch')
    ax.set_ylabel('final error')
    return ax

Credit mismatch figures for 'sum' controller


In [ ]:
file_name = 'save/rate_based_credit_results_sum.pkl'
with open(file_name, 'rb') as inp:
    inp_dict = pickle.load(inp)

res_mismatch = inp_dict['res_mismatch']
# NOTE: these replace the notebook-wide target/tmax/dt used by later cells
target = inp_dict['target']
tmax = inp_dict['tmax']
dt = inp_dict['dt']

In [ ]:
make_credit_plot(res_mismatch)
# show only mismatch fractions up to 0.5, errors up to 10
plt.xlim(0, 0.5)
plt.ylim(0, 10)

safe_save_fig('figs/ratebased_credit_mismatch_sum',
              png=False)

Credit mismatch figures for 'pushpull' controller with no subdivision


In [ ]:
file_name = 'save/rate_based_credit_results_pushpull.pkl'
with open(file_name, 'rb') as inp:
    inp_dict = pickle.load(inp)

res_mismatch = inp_dict['res_mismatch']
# NOTE: these replace the notebook-wide target/tmax/dt used by later cells
target = inp_dict['target']
tmax = inp_dict['tmax']
dt = inp_dict['dt']

In [ ]:
make_credit_plot(res_mismatch)
# full mismatch range, errors up to 10
plt.ylim(0, 10)

safe_save_fig('figs/ratebased_credit_mismatch_pushpull',
              png=False)

Credit mismatch figures for 'pushpull' controller with subdivision by 2


In [ ]:
file_name = 'save/rate_based_credit_results_pushpull_subdiv.pkl'
with open(file_name, 'rb') as inp:
    inp_dict = pickle.load(inp)

res_mismatch = inp_dict['res_mismatch']
# NOTE: these replace the notebook-wide target/tmax/dt used by later cells
target = inp_dict['target']
tmax = inp_dict['tmax']
dt = inp_dict['dt']

In [ ]:
make_credit_plot(res_mismatch)
# full mismatch range, errors up to 10
plt.ylim(0, 10)

safe_save_fig('figs/ratebased_credit_mismatch_pushpull_subdiv',
              png=False)

Non-HVC-like conductor

Here we run some simulations in which the conductor can fire arbitrary patterns, and is not restricted to HVC-like firing in which each neuron fires a single burst during the whole duration of the motor program.


In [ ]:
def tracker_generator(simulator, i, n):
    """ Attach 'out' state monitors, but only on a subset of the `n` runs.

    Monitors are created for the first 10 runs, every 10th run, and the last
    10 runs; every other run gets an empty dict (nothing is recorded).
    """
    keep = (i < 10) or (i % 10 == 0) or (n - i <= 10)
    if not keep:
        return {}

    monitored = [('motor', simulator.motor),
                 ('tutor', simulator.tutor_rule),
                 ('conductor', simulator.conductor),
                 ('student', simulator.student)]
    return {key: simulation.StateMonitor(component, 'out')
            for key, component in monitored}

def snapshot_generator_pre(simulator, i, n):
    """ Copy the student's first weight matrix before selected runs.

    Same selection as the sparse tracker: the first 10 runs, every 10th run,
    and the last 10 of the `n` runs. Other runs get an empty dict.
    """
    # De Morgan of "i < 10 or i % 10 == 0 or n - i <= 10"
    if i >= 10 and i % 10 != 0 and n - i > 10:
        return {}

    # copy, so later plasticity updates do not alter the snapshot
    return {'weights': np.copy(simulator.student.Ws[0])}

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(32234)

# average fraction of conductor neurons active at any given time
conductor_sparsity = 0.1
# duration of one time bin of the conductor pattern: within a bin, each neuron
# is either active or silent throughout
conductor_timescale = 30.0 # ms

n_conductor = 100

# generate a conductor activity pattern: in each time bin, every neuron is
# independently active with probability `conductor_sparsity`; active entries
# are scaled so that the expected summed activity per bin is 1.
# NB: tmax, dt and target here are whatever the last load cell above left in
# the global namespace.
# assumes the simulator runs int_r(tmax/dt) steps -- TODO confirm
n_total_steps = int_r(tmax/dt)
bin_steps = int_r(conductor_timescale/dt)
# round the number of bins *up* and trim afterwards, so the table always spans
# exactly the motor program even when tmax is not a multiple of the bin size
# (for tmax = 600 ms, 30 ms bins this is identical to the exact division)
conductor_steps = int(np.ceil(n_total_steps/float(bin_steps)))
conductor_pattern = np.zeros((n_conductor, conductor_steps))
conductor_pattern[np.random.rand(*conductor_pattern.shape) < conductor_sparsity] = 1.0/(
    conductor_sparsity * float(n_conductor))
conductor_pattern = np.repeat(conductor_pattern, bin_steps, axis=1)[:, :n_total_steps]

simulator = RateLearningSimulation(target, tmax, dt,
                                   n_conductor=n_conductor, n_student_per_output=1,
                                   relaxation=400.0, relaxation_conductor=25.0,
                                   conductor_burst_length=dt,
                                   conductor_from_table=conductor_pattern,
                                   tracker_generator=tracker_generator,
                                   snapshot_generator=snapshot_generator_pre,
                                   controller_mode='pushpull',
                                   tutor_rule_gain_per_student=0.5,
                                   tutor_rule_compress_rates=False,
                                   cs_weights_scale=200.0, ts_weights=0.01,
                                   plasticity_constrain_positive=True,
                                   plasticity_learning_rate=0.001,
                                   plasticity_taus=(80.0, 40.0),
                                   plasticity_params=(5.0, 4.0),
                                   tutor_rule_tau=240.0)
res = simulator.run(5000)

In [ ]:
plt.figure(figsize=(3, 2))

# every 5th row of the conductor output from run `idx` (presumably one row per
# conductor neuron), drawn in alternating black / teal
idx = 0
sel_cond_out = res[idx]['conductor'].out[::5]
draw_multi_traces(res[idx]['conductor'].t, sel_cond_out,
                  color_fct=lambda k: ('k', [0.200, 0.357, 0.400])[k % 2],
                  edge_factor=1.4, fill_alpha=0.5)
plt.xlim(0, tmax)

# keep every other y tick and label them 1, 3, 5, ...
# NOTE(review): the labels count traces in `sel_cond_out` rather than original
# neuron indices, and their number is not tied to how many ticks matplotlib
# chose -- confirm they line up with the drawn traces
tick_locs = plt.yticks()[0]
plt.yticks(tick_locs[::2], range(1, len(sel_cond_out), 2))

plt.xlabel('time (ms)')
plt.ylabel('conductor neuron index')

safe_save_fig('figs/ratebased_alt_conductor_pattern', png=False)

In [ ]:
fig = plt.figure(figsize=(3, 2))
# convergence over the first 3001 of the 5000 runs
axs = draw_convergence_plot(
    res[:3001], plt.gca(),
    target=target, target_lw=2,
    extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True, inset_pos=[0.43, 0.4, 0.45, 0.45],
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

axs[0].set_ylim(0, 15)

safe_save_fig('figs/ratebased_alt_conductor_convergence', png=False)

In [ ]:
# movie of the first 3000 runs of the alternative-conductor simulation
# NOTE(review): the tracker used for this run only records monitors for the
# first/last 10 runs and every 10th run -- confirm make_convergence_movie
# copes with runs that have no 'motor' entry
make_convergence_movie('figs/ratebased_convergence_movie_alt_conductor.mov',
                       res, target,
                       idxs=range(3000), length=12, ymax=80.0)

Non-linear (but monotonic) student-to-output relation


In [ ]:
def tracker_generator(simulator, i, n):
    """ Attach 'out' state monitors on every run (`i` and `n` are ignored).

    Monitors the motor, tutor rule, conductor and student components.
    """
    monitored = [('motor', simulator.motor),
                 ('tutor', simulator.tutor_rule),
                 ('conductor', simulator.conductor),
                 ('student', simulator.student)]
    return {key: simulation.StateMonitor(component, 'out')
            for key, component in monitored}

def snapshot_generator_pre(simulator, i, n):
    """ Copy the student's first weight matrix before every run.

    `i` and `n` are ignored. A copy is taken so that later plasticity updates
    do not alter the snapshot.
    """
    current = simulator.student.Ws[0]
    return {'weights': np.copy(current)}

In [ ]:
def sigmoidal_controller(v):
    """ Saturating, monotonic student-to-output map 100/(1 + exp(-(v - 100)/50)).

    Written with tanh, which is mathematically identical to the
    exp(x)/(1 + exp(x)) form but does not overflow to inf/inf = nan for
    large inputs: 100*e^x/(1 + e^x) = 50*(1 + tanh(x/2)), x = (v - 100)/50.
    """
    return 50.0*(1.0 + np.tanh((v - 100.0)/100.0))

# keep things arbitrary but reproducible
np.random.seed(32292)

simulator = RateLearningSimulation(target, tmax, dt,
                                   n_conductor=100, n_student_per_output=1,
                                   relaxation=400.0, relaxation_conductor=25.0,
                                   conductor_burst_length=None,
                                   tracker_generator=tracker_generator,
                                   snapshot_generator=snapshot_generator_pre,
                                   controller_mode='sum',
                                   controller_nonlinearity=sigmoidal_controller,
                                   tutor_rule_gain_per_student=0.5,
                                   tutor_rule_compress_rates=False,
                                   cs_weights_scale=200.0, ts_weights=0.01,
                                   plasticity_constrain_positive=False,
                                   plasticity_learning_rate=0.001,
                                   plasticity_taus=(80.0, 40.0),
                                   plasticity_params=(24.0, 23.0),
                                   tutor_rule_tau=1000.0)
res = simulator.run(500)

In [ ]:
# overview of the 500 sigmoidal-controller runs; `plot_evolution` is a helper
# defined outside this section (presumably in the imported helper modules)
plot_evolution(res, target, dt)

In [ ]:
fig = plt.figure(figsize=(3, 2))
# convergence of the sigmoidal-controller runs, highlighting runs 0, 10, 20
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_traces=[0, 10, 20],
    extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True, inset_pos=[0.4, 0.4, 0.45, 0.45],
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

axs[0].set_ylim(0, 15)

safe_save_fig('figs/ratebased_convergence_plot_sigmoidal_controller', png=False)

In [ ]:
# movie of all 500 sigmoidal-controller runs
make_convergence_movie('figs/ratebased_convergence_movie_sigmoidal_controller.mov',
                       res, target,
                       idxs=range(500), length=10.0, ymax=80.0)

Convergence with different kernels


In [ ]:
def tracker_generator(simulator, i, n):
    """ Attach 'out' state monitors on every run (`i` and `n` are ignored).

    Monitors the motor, tutor rule, conductor and student components.
    """
    monitored = [('motor', simulator.motor),
                 ('tutor', simulator.tutor_rule),
                 ('conductor', simulator.conductor),
                 ('student', simulator.student)]
    return {key: simulation.StateMonitor(component, 'out')
            for key, component in monitored}

def snapshot_generator_pre(simulator, i, n):
    """ Copy the student's first weight matrix before every run.

    `i` and `n` are ignored. A copy is taken so that later plasticity updates
    do not alter the snapshot.
    """
    current = simulator.student.Ws[0]
    return {'weights': np.copy(current)}

In [ ]:
# keep things arbitrary but reproducible
np.random.seed(32293)

# Same parameters as the sigmoidal-controller run above, but with a linear
# controller, the 'exp_texp' plasticity kernel, and plasticity_taus=(40, 40)
# instead of (80, 40).
# (variants also tried: controller_mode='pushpull', plasticity_learning_rate=0.002)
simulator = RateLearningSimulation(
    target, tmax, dt,
    # conductor / network layout
    n_conductor=100, n_student_per_output=1,
    conductor_burst_length=None,
    relaxation=400.0, relaxation_conductor=25.0,
    # monitoring
    tracker_generator=tracker_generator,
    snapshot_generator=snapshot_generator_pre,
    # controller and tutor
    controller_mode='sum',
    tutor_rule_gain_per_student=0.5,
    tutor_rule_compress_rates=False,
    tutor_rule_tau=1000.0,
    # weights and plasticity
    cs_weights_scale=200.0, ts_weights=0.01,
    plasticity_constrain_positive=False,
    plasticity_type='exp_texp',
    plasticity_learning_rate=0.001,
    plasticity_taus=(40.0, 40.0),
    plasticity_params=(24.0, 23.0))
res = simulator.run(200)

In [ ]:
fig = plt.figure(figsize=(3, 2))
# convergence of the alternative-kernel runs
# (extra_traces=[0, 4, 8] can be passed to highlight individual runs)
axs = draw_convergence_plot(
    res, plt.gca(),
    target=target, target_lw=2,
    extra_colors=[[0.831, 0.333, 0.000, 0.25]],
    inset=True,
    alpha=simulator.plasticity.alpha, beta=simulator.plasticity.beta,
    tau_tutor=simulator.tutor_rule.tau)

axs[0].set_ylim(0, 15)

safe_save_fig('figs/ratebased_convergence_plot_alt_kernel', png=False)

In [ ]:
# movie of all 200 alternative-kernel runs
make_convergence_movie('figs/ratebased_convergence_movie_alt_kernel.mov',
                       res, target,
                       idxs=range(200), length=4.0, ymax=80.0)

In [ ]: