Contextual Integration RNN tutorial

In this notebook, we train a vanilla RNN to integrate one of two streams of white noise, selected by a context input. This example is useful on its own for understanding how RNN training works and how to use JAX. In addition, it provides the input data for the LFADS JAX Gaussian Mixture model notebook.
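
To make the task concrete, here is a minimal sketch of what one trial might look like, assuming the usual contextual-integration setup of two noisy evidence streams plus two one-hot context lines that select which stream to integrate. The real trial generation lives in integrator.build_inputs_and_targets below, and details such as bias handling and target scaling may differ there.

import numpy as onp  # plain CPU NumPy is enough for this illustration

T, ntimesteps = 1.0, 25
dt = T / ntimesteps
bval, sval = 0.01, 0.025                     # bias limit and noise std (before /sqrt(dt))

rng = onp.random.default_rng(0)
biases = rng.uniform(-bval, bval, size=2)    # small constant bias per evidence stream
noise = (sval / onp.sqrt(dt)) * rng.standard_normal((ntimesteps, 2))
evidence = biases + noise                    # two white-noise evidence streams

context = onp.zeros((ntimesteps, 2))
chosen = rng.integers(2)                     # which stream should be integrated
context[:, chosen] = 1.0                     # one-hot context held for the whole trial

inputs = onp.concatenate([evidence, context], axis=1)  # (ntimesteps, 4), i.e. u = 4
target = onp.cumsum(evidence[:, chosen]) * dt          # integral of the chosen stream (scaling illustrative)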

Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Imports


In [1]:
# Numpy, JAX, Matplotlib and h5py should all be correctly installed and on the python path.
from __future__ import print_function, division, absolute_import
import datetime
import os
import sys
import time
from importlib import reload

# Keep JAX from preallocating all of the GPU memory.  This has to be set before
# JAX initializes its backend, so do it before the jax imports.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import h5py
import jax.numpy as np
from jax import random
from jax.experimental import optimizers
import matplotlib.pyplot as plt
import numpy as onp             # original CPU-backed NumPy

In [2]:
# Import the tutorial code.

# You must change this to the location of the computation-thru-dynamics directory.
HOME_DIR = '/home/youngjujo/' 

sys.path.append(os.path.join(HOME_DIR,'computation-thru-dynamics/experimental'))
import contextual_integrator_rnn_tutorial.integrator as integrator
import contextual_integrator_rnn_tutorial.rnn as rnn
import contextual_integrator_rnn_tutorial.utils as utils

Hyperparameters


In [3]:
# Integration parameters
T = 1.0          # Arbitrary amount of time, roughly physiological.
ntimesteps = 25  # Divide T into this many bins
bval = 0.01      # bias value limit
sval = 0.025     # standard deviation (before dividing by sqrt(dt))
input_params = (bval, sval, T, ntimesteps)

# Integrator RNN hyperparameters
u = 4         # Number of inputs to the RNN
n = 100       # Number of units in the RNN
o = 1         # Number of outputs in the RNN

# The scaling of the recurrent parameters in an RNN really matters.
# The correct scaling is 1/sqrt(number of recurrent inputs), which
# yields an order-1 summed input to each neuron if the individual
# signals are order 1.  Given that the VRNN uses a tanh nonlinearity,
# with min and max output values of -1 and 1, this works out.  The
# scaling just below 1 is because we know we are making a line
# attractor, so we might as well start it off basically right.  A
# value of 1.0 is also basically right, but perhaps will lead to
# crazier dynamics.
param_scale = 0.8 # Scaling of the recurrent weight matrix

# Optimization hyperparameters
num_batchs = 10000         # Total number of batches to train on.
batch_size = 128          # How many examples in each batch
eval_batch_size = 1024    # How large a batch for evaluating the RNN
step_size = 0.025          # initial learning rate
decay_factor = 0.99975     # decay the learning rate this much
# Gradient clipping is HUGELY important for training RNNs
max_grad_norm = 10.0      # max gradient norm before clipping, clip to this value.
l2reg = 0.0002           # amount of L2 regularization on the weights
adam_b1 = 0.9             # Adam parameters
adam_b2 = 0.999
adam_eps = 1e-1
print_every = 100          # Print training information every so often
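
As noted above, gradient clipping matters a lot for RNN training. The actual clipped update lives in rnn.update_w_gc below; purely as an illustration of the clipping step, a global-norm clip of a gradient pytree can be written with optimizers.clip_grads. The loss_fn, params, inputs and targets names here are placeholders, not the tutorial's own functions.

from jax import grad
from jax.experimental import optimizers

def sketch_clipped_grads(loss_fn, params, inputs, targets, max_norm):
    # Compute gradients of a scalar loss w.r.t. the parameter pytree...
    grads = grad(loss_fn)(params, inputs, targets)
    # ...then rescale the whole pytree so its global L2 norm is at most max_norm.
    return optimizers.clip_grads(grads, max_norm)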

In [4]:
# JAX handles randomness differently than numpy or matlab:
# one threads an explicit random key through to each function.
# It's a bit tedious, but very easy to understand and gives
# reproducible results.
seed = onp.random.randint(0, 1000000) # get randomness from CPU level numpy
print("Seed: %d" % seed)
key = random.PRNGKey(seed) # create a random key for jax for use on device.

# Plot a few input/target examples to make sure things look sane.
ntoplot = 10    # how many examples to plot
# With this split command, we are always getting a new key from the old key,
# and I use the first key as a source of randomness for new keys.
#     key, subkey = random.split(key, 2)
#     ## do something random with subkey
#     key, subkey = random.split(key, 2)
#     ## do something random with subkey
# In this way, the top-level key is never reused directly, so every draw gets fresh randomness.

# The number of examples to plot is given by the number of 
# random keys in this function.
key, skey = random.split(key, 2)
skeys = random.split(skey, ntoplot) # get ntoplot random keys
reload(integrator)
inputs, targets = integrator.build_inputs_and_targets(input_params, skeys)

# Plot the input to the RNN and the target for the RNN.
integrator.plot_batch(ntimesteps, inputs, targets, ntoplot=1)


Seed: 218525
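
As a quick aside on the key discipline above: once the seed is fixed, every downstream draw is reproducible, because splitting a key is itself deterministic. A tiny self-contained check:

from jax import random

k = random.PRNGKey(0)
k1, k2 = random.split(k, 2)
print(random.normal(k1, (3,)))   # same numbers every time this is run
print(random.normal(k2, (3,)))   # different from k1's draw, but also reproducible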

In [5]:
# Init some parameters for training.
reload(rnn)
# Note this draws a fresh seed from CPU NumPy, so the parameter init is not
# tied to the seed printed above.
key = random.PRNGKey(onp.random.randint(100000000))
init_params = rnn.random_vrnn_params(key, u, n, o, g=param_scale)
rnn.plot_params(init_params)
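
rnn.random_vrnn_params builds all of the VRNN parameters (recurrent and input/output weights, plus the initial state h0). The piece the param_scale comment above refers to is the recurrent matrix; a hedged sketch of that kind of initializer, not the tutorial's exact code:

from jax import random
import jax.numpy as np

def sketch_recurrent_init(key, n, g):
    # Scale an n x n recurrent weight matrix by g / sqrt(n), so each unit's
    # summed recurrent input stays order 1 when the activities are order 1.
    return g * random.normal(key, (n, n)) / np.sqrt(n)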



In [6]:
# Create a decay function for the learning rate
decay_fun = optimizers.exponential_decay(step_size, decay_steps=1, 
                                         decay_rate=decay_factor)

batch_idxs = onp.linspace(1, num_batchs)
plt.plot(batch_idxs, [decay_fun(b) for b in batch_idxs])
plt.axis('tight')
plt.xlabel('Batch number')
plt.ylabel('Learning rate');
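
With decay_steps=1, this schedule should just be geometric decay of the initial step size, i.e. step_size * decay_factor**batch; the step sizes printed during training below are consistent with that. A quick sanity check (a sketch, assuming that closed form for optimizers.exponential_decay):

checks = onp.arange(0, num_batchs, 1000)
manual = step_size * decay_factor ** checks
scheduled = onp.array([decay_fun(int(i)) for i in checks])
print(onp.allclose(manual, scheduled))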


Train the VRNN


In [7]:
reload(rnn)
# Initialize the optimizer.  Please see jax/experimental/optimizers.py
opt_init, opt_update, get_params = optimizers.adam(decay_fun, adam_b1, adam_b2, adam_eps)
opt_state = opt_init(init_params)

# Run the optimization loop; the first jit'd call will take a minute to compile.
start_time = time.time()
all_train_losses = []
for batch in range(num_batchs):
    key = random.fold_in(key, batch)
    skeys = random.split(key, batch_size)
    inputs, targets = integrator.build_inputs_and_targets_jit(input_params, skeys)
    opt_state = rnn.update_w_gc_jit(batch, opt_state, opt_update, get_params,
                                    inputs, targets, max_grad_norm, l2reg)
    if batch % print_every == 0:
        params = get_params(opt_state)
        all_train_losses.append(rnn.loss_jit(params, inputs, targets, l2reg))
        train_loss = all_train_losses[-1]['total']
        batch_time = time.time() - start_time
        step_size = decay_fun(batch)
        s = "Batch {} in {:0.2f} sec, step size: {:0.5f}, training loss {:0.4f}"
        print(s.format(batch, batch_time, step_size, train_loss))
        start_time = time.time()
        
# List of dicts to dict of lists
all_train_losses = {k: [dic[k] for dic in all_train_losses] for k in all_train_losses[0]}


Batch 0 in 3.27 sec, step size: 0.02500, training loss 0.7996
Batch 100 in 0.44 sec, step size: 0.02438, training loss 0.2122
Batch 200 in 0.51 sec, step size: 0.02378, training loss 0.1404
Batch 300 in 0.50 sec, step size: 0.02319, training loss 0.0468
Batch 400 in 0.44 sec, step size: 0.02262, training loss 0.0447
Batch 500 in 0.43 sec, step size: 0.02206, training loss 0.0435
Batch 600 in 0.43 sec, step size: 0.02152, training loss 0.0340
Batch 700 in 0.50 sec, step size: 0.02099, training loss 0.0357
Batch 800 in 0.45 sec, step size: 0.02047, training loss 0.0345
Batch 900 in 0.44 sec, step size: 0.01996, training loss 0.0519
Batch 1000 in 0.43 sec, step size: 0.01947, training loss 0.0327
Batch 1100 in 0.44 sec, step size: 0.01899, training loss 0.0315
Batch 1200 in 0.44 sec, step size: 0.01852, training loss 0.0293
Batch 1300 in 0.44 sec, step size: 0.01806, training loss 0.0293
Batch 1400 in 0.45 sec, step size: 0.01762, training loss 0.0305
Batch 1500 in 0.45 sec, step size: 0.01718, training loss 0.0291
Batch 1600 in 0.44 sec, step size: 0.01676, training loss 0.0300
Batch 1700 in 0.44 sec, step size: 0.01634, training loss 0.0328
Batch 1800 in 0.44 sec, step size: 0.01594, training loss 0.0286
Batch 1900 in 0.44 sec, step size: 0.01555, training loss 0.0276
Batch 2000 in 0.45 sec, step size: 0.01516, training loss 0.0270
Batch 2100 in 0.46 sec, step size: 0.01479, training loss 0.0258
Batch 2200 in 0.43 sec, step size: 0.01442, training loss 0.0261
Batch 2300 in 0.44 sec, step size: 0.01407, training loss 0.0253
Batch 2400 in 0.45 sec, step size: 0.01372, training loss 0.0255
Batch 2500 in 0.45 sec, step size: 0.01338, training loss 0.0255
Batch 2600 in 0.45 sec, step size: 0.01305, training loss 0.0255
Batch 2700 in 0.46 sec, step size: 0.01273, training loss 0.0259
Batch 2800 in 0.47 sec, step size: 0.01241, training loss 0.0252
Batch 2900 in 0.45 sec, step size: 0.01211, training loss 0.0242
Batch 3000 in 0.47 sec, step size: 0.01181, training loss 0.0259
Batch 3100 in 0.47 sec, step size: 0.01152, training loss 0.0235
Batch 3200 in 0.45 sec, step size: 0.01123, training loss 0.0239
Batch 3300 in 0.45 sec, step size: 0.01095, training loss 0.0233
Batch 3400 in 0.45 sec, step size: 0.01068, training loss 0.0230
Batch 3500 in 0.45 sec, step size: 0.01042, training loss 0.0228
Batch 3600 in 0.47 sec, step size: 0.01016, training loss 0.0229
Batch 3700 in 0.46 sec, step size: 0.00991, training loss 0.0226
Batch 3800 in 0.46 sec, step size: 0.00967, training loss 0.0225
Batch 3900 in 0.51 sec, step size: 0.00943, training loss 0.0221
Batch 4000 in 0.52 sec, step size: 0.00920, training loss 0.0219
Batch 4100 in 0.45 sec, step size: 0.00897, training loss 0.0223
Batch 4200 in 0.45 sec, step size: 0.00875, training loss 0.0220
Batch 4300 in 0.45 sec, step size: 0.00853, training loss 0.0216
Batch 4400 in 0.45 sec, step size: 0.00832, training loss 0.0217
Batch 4500 in 0.45 sec, step size: 0.00812, training loss 0.0212
Batch 4600 in 0.44 sec, step size: 0.00791, training loss 0.0213
Batch 4700 in 0.49 sec, step size: 0.00772, training loss 0.0216
Batch 4800 in 0.43 sec, step size: 0.00753, training loss 0.0208
Batch 4900 in 0.43 sec, step size: 0.00734, training loss 0.0211
Batch 5000 in 0.43 sec, step size: 0.00716, training loss 0.0206
Batch 5100 in 0.49 sec, step size: 0.00698, training loss 0.0206
Batch 5200 in 0.43 sec, step size: 0.00681, training loss 0.0205
Batch 5300 in 0.43 sec, step size: 0.00664, training loss 0.0204
Batch 5400 in 0.43 sec, step size: 0.00648, training loss 0.0203
Batch 5500 in 0.43 sec, step size: 0.00632, training loss 0.0203
Batch 5600 in 0.43 sec, step size: 0.00616, training loss 0.0202
Batch 5700 in 0.42 sec, step size: 0.00601, training loss 0.0201
Batch 5800 in 0.42 sec, step size: 0.00586, training loss 0.0200
Batch 5900 in 0.43 sec, step size: 0.00572, training loss 0.0198
Batch 6000 in 0.43 sec, step size: 0.00558, training loss 0.0198
Batch 6100 in 0.43 sec, step size: 0.00544, training loss 0.0196
Batch 6200 in 0.43 sec, step size: 0.00531, training loss 0.0196
Batch 6300 in 0.43 sec, step size: 0.00517, training loss 0.0195
Batch 6400 in 0.43 sec, step size: 0.00505, training loss 0.0194
Batch 6500 in 0.43 sec, step size: 0.00492, training loss 0.0194
Batch 6600 in 0.49 sec, step size: 0.00480, training loss 0.0194
Batch 6700 in 0.49 sec, step size: 0.00468, training loss 0.0193
Batch 6800 in 0.45 sec, step size: 0.00457, training loss 0.0193
Batch 6900 in 0.42 sec, step size: 0.00445, training loss 0.0194
Batch 7000 in 0.48 sec, step size: 0.00434, training loss 0.0190
Batch 7100 in 0.49 sec, step size: 0.00424, training loss 0.0190
Batch 7200 in 0.47 sec, step size: 0.00413, training loss 0.0189
Batch 7300 in 0.43 sec, step size: 0.00403, training loss 0.0189
Batch 7400 in 0.42 sec, step size: 0.00393, training loss 0.0188
Batch 7500 in 0.42 sec, step size: 0.00383, training loss 0.0187
Batch 7600 in 0.42 sec, step size: 0.00374, training loss 0.0187
Batch 7700 in 0.43 sec, step size: 0.00365, training loss 0.0187
Batch 7800 in 0.43 sec, step size: 0.00356, training loss 0.0187
Batch 7900 in 0.45 sec, step size: 0.00347, training loss 0.0186
Batch 8000 in 0.50 sec, step size: 0.00338, training loss 0.0186
Batch 8100 in 0.43 sec, step size: 0.00330, training loss 0.0185
Batch 8200 in 0.43 sec, step size: 0.00322, training loss 0.0185
Batch 8300 in 0.43 sec, step size: 0.00314, training loss 0.0184
Batch 8400 in 0.42 sec, step size: 0.00306, training loss 0.0183
Batch 8500 in 0.43 sec, step size: 0.00299, training loss 0.0182
Batch 8600 in 0.43 sec, step size: 0.00291, training loss 0.0182
Batch 8700 in 0.43 sec, step size: 0.00284, training loss 0.0182
Batch 8800 in 0.48 sec, step size: 0.00277, training loss 0.0182
Batch 8900 in 0.47 sec, step size: 0.00270, training loss 0.0181
Batch 9000 in 0.47 sec, step size: 0.00263, training loss 0.0182
Batch 9100 in 0.43 sec, step size: 0.00257, training loss 0.0182
Batch 9200 in 0.43 sec, step size: 0.00251, training loss 0.0180
Batch 9300 in 0.43 sec, step size: 0.00244, training loss 0.0180
Batch 9400 in 0.43 sec, step size: 0.00238, training loss 0.0180
Batch 9500 in 0.43 sec, step size: 0.00232, training loss 0.0179
Batch 9600 in 0.47 sec, step size: 0.00227, training loss 0.0180
Batch 9700 in 0.47 sec, step size: 0.00221, training loss 0.0178
Batch 9800 in 0.43 sec, step size: 0.00216, training loss 0.0179
Batch 9900 in 0.45 sec, step size: 0.00210, training loss 0.0178

In [8]:
# Show the loss through training.
xlims = [2, 50]
plt.figure(figsize=(16,4))
plt.subplot(141)
plt.plot(all_train_losses['total'][xlims[0]:xlims[1]], 'k')
plt.title('Total')

plt.subplot(142)
plt.plot(all_train_losses['lms'][xlims[0]:xlims[1]], 'r')
plt.title('Least mean square')

plt.subplot(143)
plt.plot(all_train_losses['l2'][xlims[0]:xlims[1]], 'g');
plt.title('L2')

plt.subplot(144)
plt.plot(all_train_losses['total'][xlims[0]:xlims[1]], 'k')
plt.plot(all_train_losses['lms'][xlims[0]:xlims[1]], 'r')
plt.plot(all_train_losses['l2'][xlims[0]:xlims[1]], 'g')
plt.title('All losses')


Out[8]:
Text(0.5, 1.0, 'All losses')

Testing


In [9]:
# Take a fresh batch for an evaluation loss.  Note the L2 penalty is 0 for the
# evaluation, which is why this number is much smaller than the training losses
# above (those include the L2 term).
params = get_params(opt_state)

key, subkey = random.split(key, 2)
skeys = random.split(subkey, eval_batch_size)  # use the larger evaluation batch
inputs, targets = integrator.build_inputs_and_targets_jit(input_params, skeys)
eval_loss = rnn.loss_jit(params, inputs, targets, l2reg=0.0)['total']
eval_loss_str = "{:.5f}".format(eval_loss)
print("Loss on a new large batch: %s" % (eval_loss_str))


Loss on a new large batch: 0.00024

Visualizations of the trained system


In [10]:
reload(rnn)

# Visualize how good this trained integrator is
def inputs_targets_no_h0s(keys):
    inputs_b, targets_b = \
        integrator.build_inputs_and_targets_jit(input_params, keys)
    h0s_b = None # Use trained h0
    return inputs_b, targets_b, h0s_b

rnn_run = lambda inputs: rnn.batched_rnn_run(params, inputs)

# Used later (in the per-trial initial-condition cells) to broadcast the trained h0.
give_trained_h0 = lambda batch_size: np.array([params['h0']] * batch_size)

rnn_internals = rnn.run_trials(rnn_run, inputs_targets_no_h0s, 1, 16)

integrator.plot_batch(ntimesteps, rnn_internals['inputs'], 
                      rnn_internals['targets'], rnn_internals['outputs'], 
                      onp.abs(rnn_internals['targets'] - rnn_internals['outputs']))



In [11]:
# Visualize the hidden state, as an example.
reload(rnn)
rnn.plot_examples(ntimesteps, rnn_internals, nexamples=4)



In [12]:
# Take a look at the trained parameters.
rnn.plot_params(params)


Saving


In [13]:
# Define directories, etc.
task_type = 'contextual_int'
rnn_type = 'vrnn'
fname_uniquifier = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
data_dir = os.path.join('/tmp', rnn_type, task_type)
os.makedirs(data_dir, exist_ok=True)  # make sure the save directory exists

print(data_dir)
print(fname_uniquifier)


/tmp/vrnn/contextual_int
2020-06-16_00:14:21

In [14]:
# Save parameters

params_fname = ('trained_params_' + rnn_type + '_' + task_type + '_' + \
                eval_loss_str + '_' + fname_uniquifier + '.h5')
params_fname = os.path.join(data_dir, params_fname)

print("Saving params in %s" % (params_fname))
utils.write_file(params_fname, params)


Saving params in /tmp/vrnn/contextual_int/trained_params_vrnn_contextual_int_0.00024_2020-06-16_00:14:21.h5
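
If you want to inspect the saved file later, something like this should work, assuming utils.write_file stores each parameter as an HDF5 dataset keyed by its name (an assumption about the file layout; if utils provides a matching reader, prefer that):

import h5py
import numpy as onp

with h5py.File(params_fname, 'r') as f:
    loaded_params = {k: onp.array(f[k]) for k in f.keys()}
print(sorted(loaded_params.keys()))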

Create per-trial initial conditions along the line attractor.

Let's create some per-trial initial conditions by running the integrator for a few extra time steps and taking the hidden state at that point as each trial's initial state. This makes the learned initial states in the LFADS tutorial easier to compare and visualize.


In [15]:
nsave_batches = 20 # Save about 20000 trials

h0_ntimesteps = 30 # Run a few extra steps; the early steps generate a distribution of initial conditions
h0_input_params = (bval, sval,  
                   T * float(h0_ntimesteps) / float(ntimesteps), 
                   h0_ntimesteps)

In [19]:
def get_h0s_inputs_targets_h0s(keys):
    inputs_bxtxu, targets_bxtxm = \
        integrator.build_inputs_and_targets_jit(h0_input_params, keys)
    h0s = give_trained_h0(len(keys))
    return (inputs_bxtxu, targets_bxtxm, h0s)
    
rnn_run_w_h0 = lambda inputs, h0s: rnn.batched_rnn_run_w_h0(params, inputs, h0s)

h0_data_dict = rnn.run_trials(rnn_run_w_h0, get_h0s_inputs_targets_h0s,
                              nsave_batches, eval_batch_size)

In [20]:
# Keep only the last ntimesteps steps of each longer trial; the hidden state at
# the cut point becomes that trial's initial condition.
ncut = h0_ntimesteps - ntimesteps  # = 5 extra steps used only to generate h0s
data_dict = {}
data_dict['inputs'] = h0_data_dict['inputs'][:, ncut:, :]
data_dict['hiddens'] = h0_data_dict['hiddens'][:, ncut:, :]
data_dict['outputs'] = h0_data_dict['outputs'][:, ncut:, :]
data_dict['targets'] = h0_data_dict['targets'][:, ncut:, :]
data_dict['h0s'] = h0_data_dict['hiddens'][:, ncut, :]

rnn.plot_examples(ntimesteps, data_dict, nexamples=4)



In [21]:
data_fname = ('trained_data_' + rnn_type + '_' + task_type + '_' + \
              eval_loss_str + '_' + fname_uniquifier + '.h5')
data_fname = os.path.join(data_dir, data_fname)
print("Saving data in %s" %(data_fname))
utils.write_file(data_fname, data_dict)


Saving data in /tmp/vrnn/contextual_int/trained_data_vrnn_contextual_int_0.00024_2020-06-16_00:14:21.h5