MPLP: Sinusoid 5 step, 10 shot

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


In [ ]:
import glob
import io
import itertools
import os
import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt  # visualization

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
tf.get_logger().setLevel('ERROR')

import tensorflow_datasets as tfds

import IPython.display as display
from IPython.display import clear_output

from PIL import Image

In [ ]:
!pip install --upgrade -e "git+https://github.com/google-research/self-organising-systems.git#egg=mplp&subdirectory=mplp"

In [ ]:
# symlink for saved models.
!ln -s src/mplp/mplp/savedmodels savedmodels

In [ ]:
from mplp.tf_layers import MPDense
from mplp.tf_layers import MPActivation
from mplp.tf_layers import MPSoftmax
from mplp.tf_layers import MPL2Loss
from mplp.tf_layers import MPNetwork
from mplp.sinusoidals import SinusoidalsDS
from mplp.util import SamplePool
from mplp.training import TrainingRegime

The task is to fit sinusoids starting from randomly initialized networks.

Concretely (a small illustrative sketch of this structure follows the list):

  • Outer batch size = 4: the number of tasks at every meta-step. Each task has a different network, a different amplitude, and a different phase.
  • Inner batch size = 10: the number of examples used for each forward/backward step.
  • Num steps = 5: the number of inner steps the network gets to improve.
  • Train/eval split: the network only sees the train instances during forward/backward. The meta-learning regime may choose to use the eval splits as well, MAML-style.
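
Purely as an illustration of that structure (not the library's actual sampler), a single task can be sketched in NumPy. The amplitude and phase ranges below are assumptions in the spirit of the usual sinusoid benchmark, not values taken from SinusoidalsDS:

In [ ]:
# Hypothetical sketch of one few-shot sinusoid task (ranges are assumptions).
def sketch_task(num_steps=5, inner_batch_size=10, rng=np.random):
  A = rng.uniform(0.1, 5.0)      # assumed amplitude range
  ph = rng.uniform(0.0, np.pi)   # assumed phase range
  # One inner batch of inputs per inner step, drawn from the range [-5, 5].
  x = rng.uniform(-5.0, 5.0, size=(num_steps, inner_batch_size, 1)).astype(np.float32)
  y = (A * np.sin(x + ph)).astype(np.float32)
  return x, y

x_sketch, y_sketch = sketch_task()
print(x_sketch.shape, y_sketch.shape)  # (5, 10, 1) (5, 10, 1)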

In [ ]:
# @title create dataset and plot it
OUTER_BATCH_SIZE = 4
INNER_BATCH_SIZE = 10
NUM_STEPS = 5

ds_factory = SinusoidalsDS()

ds = ds_factory.create_ds(OUTER_BATCH_SIZE, INNER_BATCH_SIZE, NUM_STEPS)
ds_iter = iter(ds)

# Utility range
xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)).astype(np.float32)

xtb, ytb, xeb, yeb = next(ds_iter)
plt.figure(figsize=(14, 10))
colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
for xts, yts, xes, yes in zip(xtb, ytb, xeb, yeb):
  c_t = next(colors)
  c_e = next(colors)
  markers = itertools.cycle((',', '+', '.', 'o', '*')) 
  for xtsib, ytsib, xesib, yesib in zip(xts, yts, xes, yes):
    marker = next(markers)
    plt.scatter(xtsib, ytsib, c=c_t, marker=marker)
    plt.scatter(xesib, yesib, c=c_e, marker=marker)

plt.show()

Create an MP network:


In [ ]:
# This is the size of the message passed.
message_size = 8
stateful = True
stateful_hidden_n = 15

# This network is initialized Keras-style.
# If you want to create a single layer on its own, you also need to pass it
# the in_dim and message size.
network = MPNetwork(
    [
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), 
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     ],
     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))
network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)

# see trainable weights:
tr_w = network.get_trainable_weights()
print("trainable weights:")
tot_w = 0
for w in tr_w:
  w = w.numpy()
  w_size = w.size
  tot_w += w_size

  print(w.shape, w_size)
print("tot n:", tot_w)

In [ ]:
def create_new_p_fw():
  return network.init(INNER_BATCH_SIZE)

POOL_SIZE = 128

pool = SamplePool(ps=tf.stack(
    [network.serialize_pfw(create_new_p_fw()) for _ in range(POOL_SIZE)]).numpy())

In [ ]:
num_steps = tf.constant(NUM_STEPS)

learning_schedule = 1e-4

# Prepare a training regime.
# The heldout_weight controls how the loss is split between the train and eval
# sets that are passed to the network.
# Empirically, heldout_weight=0.0 (or None) results in much worse overall
# performance, for both train and test losses.
training_regime = TrainingRegime(
    network, heldout_weight=1.0, hint_loss_ratio=0.7, remember_loss_ratio=0.0)

last_step = 0


# Minibatch init: initialize the statistics by looking at more than just one
# step. This can also be run multiple times to improve the initialization.
for j in range(1):
  print("on", j)
  stats = []
  pfw = network.init(INNER_BATCH_SIZE)

  x_b, y_b, _, _ = next(ds_iter)
  x_b, y_b = x_b[0], y_b[0]
  for i in range(NUM_STEPS):
    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)
    stats.append(stats_i)
  # update
  network.update_statistics(stats, update_perc=1.)

  print("final mean:")
  for p in tf.nest.flatten(pfw):
    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))


# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.
trainer = tf.keras.optimizers.Adam(learning_schedule)

loss_log = []
def smoothen(l, lookback=20):
  kernel = [1./lookback] * lookback
  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")
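
# Quick illustrative sanity check of the smoothing helper (toy input, not data
# from the notebook): the input is left-padded with its first value, so the
# output keeps the same length as the input and the x-axis of the loss plots
# still matches the step count.
print(smoothen([1.0] * 5 + [0.0] * 5, lookback=2))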

In [ ]:
training_steps = 200000

@tf.function
def step(pfws, xts, yts, xes, yes, num_steps):
  print("compiling")
  with tf.GradientTape() as g:
    l, _, _ = training_regime.batch_mp_loss(pfws, xts, yts, xes, yes, num_steps)
  all_weights = network.get_trainable_weights()
  grads = g.gradient(l, all_weights)
  # Normalize gradients (rather than clip them) to avoid explosions.
  grads = [g / (tf.norm(g) + 1e-8) for g in grads]
  trainer.apply_gradients(zip(grads, all_weights))
  return l


import time
start_time = time.time()

for i in range(last_step + 1, last_step + 1 + training_steps):
  last_step = i

  tmp_t = time.time()
  xts, yts, xes, yes = next(ds_iter)

  batch = pool.sample(OUTER_BATCH_SIZE)
  fwps = batch.ps

  l = step(fwps, xts, yts, xes, yes, num_steps)
  loss_log.append(l)

  if i % 50 == 0:
    print(i)
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    plt.plot(smoothen(loss_log, 100), label='mp')
    plt.yscale('log')
    plt.legend()
    plt.show()
print("--- %s seconds ---" % (time.time() - start_time))

In [ ]:
print(loss_log[-1])
plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
plt.ylim(0.0, 1e-1)
plt.gca().yaxis.grid(True)
plt.legend()
plt.show()

Proper evaluation: run 100 different few-shot instances with entirely new network params.

The train loss is computed only on points that the network has already observed.

The eval loss is computed on the entire range [-5, 5]. (A minimal sketch of the metric follows.)
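
As a reference, here is a minimal NumPy stand-in for the metric, assuming MPL2Loss is a per-example squared error; the helper name is illustrative, not from the library:

In [ ]:
# Illustrative stand-in for the evaluation metric (not library code).
def l2(pred, target):
  # Mean squared error, assumed to match what MPL2Loss computes per example.
  return np.mean((np.asarray(pred) - np.asarray(target)) ** 2)

# After inner step i, the train loss uses only the (i + 1) * INNER_BATCH_SIZE
# points observed so far; the eval loss always uses the 100-point dense grid
# xrange_inputs against the true sinusoid.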


In [ ]:
!mkdir -p tmp

!ls tmp -R

In [ ]:
!mkdir -p tmp
file_path = "tmp/weights"

network.save_weights(file_path, last_step)
!ls -lh tmp

In [ ]:
# @title Optionally, load saved model
checkpoint_file_path = "savedmodels/sinusoid_weights"
network = MPNetwork(
    [
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), 
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     ],
     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))
network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)
network.load_weights(checkpoint_file_path)

In [ ]:
eval_tot_steps = 100

tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])
ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0-step.

@tf.function
def get_loss(pfw, x, y):
  predictions, _= network.forward(pfw, x)
  loss, _ = network.compute_loss(predictions, y)
  return loss

start_time = time.time()

for r in range(eval_tot_steps):
  p_fw = network.init(INNER_BATCH_SIZE)

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)

  # initial loss.
  loss = get_loss(p_fw, xrange_inputs, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    loss = get_loss(p_fw, x_observed_so_far, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Plotting for the continuous input range
    loss = get_loss(p_fw, xrange_inputs, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)
print("--- %s seconds ---" % (time.time() - start_time))

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)

tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)

ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.025)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

with open("tmp/mplp_losses.png", "wb") as fout:
  plt.savefig(fout)

In [ ]:
print(tr_losses_m, ev_losses_m)

In [ ]:
# @title Show example runs:

fig, axs = plt.subplots(5, 2, figsize=(10,15))

for fig_n in range(5):
  p_fw = network.init(INNER_BATCH_SIZE)

  n_plot = 5
  plot_every = NUM_STEPS // n_plot if NUM_STEPS >= n_plot else NUM_STEPS

  predictions, _ = network.forward(p_fw, xrange_inputs)

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)
  axs[fig_n][0].plot(xrange_inputs, targets, label='target')

  predictions, _= network.forward(p_fw, xrange_inputs)
  axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  tr_losses = []
  ev_losses = []

  for i in range(NUM_STEPS):
    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _= network.forward(p_fw, x_observed_so_far)
    loss, _ = network.compute_loss(predictions, y_observed_so_far)
    tr_losses.append(tf.reduce_mean(loss))

    # Plotting for the continuous input range
    predictions, _= network.forward(p_fw, xrange_inputs)
    if (i+1) % plot_every == 0:
      axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
    loss, _ = network.compute_loss(predictions, targets)
    ev_losses.append(tf.reduce_mean(loss))

  axs[fig_n][1].plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')
  axs[fig_n][1].plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')

axs[0][0].legend()
axs[0][1].legend()

In [ ]:
# @title Single run for drawing.

p_fw = network.init(INNER_BATCH_SIZE)

n_plot = 5
plot_every = NUM_STEPS // n_plot if NUM_STEPS >= n_plot else NUM_STEPS

predictions, _= network.forward(p_fw, xrange_inputs)

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _ = network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _= network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))


plt.legend()


with open("tmp/mplp_example_run.png", "wb") as fout:
  plt.savefig(fout)

Compare it with an Adam run (the best configuration I managed to find).


In [ ]:
adam_network = MPNetwork(
      [MPDense(20),
       MPActivation(tf.nn.relu),
       MPDense(20), 
       MPActivation(tf.nn.relu),
       MPDense(1),
      ],
      MPL2Loss())
adam_network.setup(in_dim=1, message_size=message_size)

In [ ]:
eval_tot_steps = 100

tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])
ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0 step.

for r in range(eval_tot_steps):
  # We need to transform these into variables.
  p_fw = [tf.Variable(t) for t in adam_network.init()]
  adam_trainer = tf.keras.optimizers.Adam(0.01)

  def adam_step(xt, yt):
    with tf.GradientTape() as g:
      g.watch(p_fw)
      y, _ = adam_network.forward(p_fw, xt)
      l, _ = adam_network.compute_loss(y, yt)
      l = tf.reduce_mean(l)

    grads = g.gradient(l, p_fw)
    adam_trainer.apply_gradients(zip(grads, p_fw))

    return l

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)

  # initial loss.
  predictions, _ = adam_network.forward(p_fw, xrange_inputs)
  loss, _ = adam_network.compute_loss(predictions, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    adam_step(xt[i], yt[i])

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _= adam_network.forward(p_fw, x_observed_so_far)
    loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Plotting for the continuous input range
    predictions, _= adam_network.forward(p_fw, xrange_inputs)
    loss, _ = adam_network.compute_loss(predictions, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)

tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)

ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.04)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

with open("tmp/adam_losses.png", "wb") as fout:
  plt.savefig(fout)

In [ ]:
# Same setup as the MPLP single run above, for an analogous drawing.

p_fw = [tf.Variable(t) for t in adam_network.init()]
adam_trainer = tf.keras.optimizers.Adam(0.01)

n_plot = 5
plot_every = NUM_STEPS // n_plot

def adam_step(xt, yt):
  with tf.GradientTape() as g:
    g.watch(p_fw)
    y, _ = adam_network.forward(p_fw, xt)
    l, _ = adam_network.compute_loss(y, yt)
    l = tf.reduce_mean(l)

  grads = g.gradient(l, p_fw)
  adam_trainer.apply_gradients(zip(grads, p_fw))

  return l

predictions, _ = adam_network.forward(p_fw, xrange_inputs)

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _= adam_network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  adam_step(xt[i], yt[i])

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _= adam_network.forward(p_fw, x_observed_so_far)
  loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _= adam_network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = adam_network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()

with open("tmp/adam_example_run.png", "wb") as fout:
  plt.savefig(fout)

In [ ]:
# @title Show an example run:
p_fw = [tf.Variable(t) for t in adam_network.init()]
adam_trainer = tf.keras.optimizers.Adam(0.01)

n_plot = 5
plot_every = NUM_STEPS // n_plot

def adam_step(xt, yt):
  with tf.GradientTape() as g:
    g.watch(p_fw)
    y, _ = adam_network.forward(p_fw, xt)
    l, _ = adam_network.compute_loss(y, yt)
    l = tf.reduce_mean(l)

  grads = g.gradient(l, p_fw)
  adam_trainer.apply_gradients(zip(grads, p_fw))

  return l

predictions, _ = adam_network.forward(p_fw, xrange_inputs)

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _= adam_network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  adam_step(xt[i], yt[i])

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _ = adam_network.forward(p_fw, x_observed_so_far)
  loss, _ = adam_network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _ = adam_network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = adam_network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()
plt.show()

plt.plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')
plt.plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')
plt.legend()
plt.show()