MPLP: MNIST 20 step, 8 shot

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


In [ ]:
import os
import io
import glob
import itertools
import random as pyrandom  # used by MNIST_generator below
from collections import defaultdict

import numpy as np
import numpy as onp  # some cells below use the `onp` alias for numpy

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
tf.get_logger().setLevel('ERROR')

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt # visualization

import IPython.display as display
from PIL import Image

In [ ]:
!pip install --upgrade -e "git+https://github.com/google-research/self-organising-systems.git#egg=mplp&subdirectory=mplp"

In [ ]:
# symlink for saved models.
!ln -s src/mplp/mplp/savedmodels savedmodels

In [ ]:
from mplp.tf_layers import MPDense
from mplp.tf_layers import MPActivation
from mplp.tf_layers import MPSoftmax
from mplp.tf_layers import MPCrossEntropyLoss
from mplp.tf_layers import MPNetwork
from mplp.util import SamplePool
from mplp.training import TrainingRegime
from mplp.core import GRUBlock
from mplp.core import OutStandardizer
from mplp.core import StandardizeInputsAndStates

In [ ]:
def load_mnist():
  (x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0

  x_train = np.array(x_train).astype(np.float32)
  y_train = np.array(y_train).astype(np.int64)
  x_test = np.array(x_test).astype(np.float32)
  y_test = np.array(y_test).astype(np.int64)
  return (x_train, y_train),(x_test, y_test)

(x_train, y_train),(x_test, y_test) = load_mnist()

def one_hottify(dset):
  one_hottified = onp.zeros([dset.shape[0], 10], dtype=onp.float32)
  one_hottified[onp.arange(dset.size), dset] = 1.0
  return one_hottified

def MNIST_generator(inputs, labels):
  while True:
    idx = pyrandom.randrange(len(inputs))
    x = inputs[idx]
    y = labels[idx]
    yield x, y

PIC_L = 12

x_train = tf.image.resize(onp.expand_dims(x_train, -1), size=(PIC_L, PIC_L)).numpy()
x_test = tf.image.resize(onp.expand_dims(x_test, -1), size=(PIC_L, PIC_L)).numpy()
y_train = one_hottify(y_train)
y_test = one_hottify(y_test)

# standardize inputs:
train_mean, train_std = x_train.mean(), x_train.std()

x_train = (x_train - train_mean) / train_std
x_test = (x_test - train_mean) / train_std

x_train = x_train.reshape([-1, PIC_L * PIC_L])
x_test = x_test.reshape([-1, PIC_L * PIC_L])
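
A quick optional check (a minimal sketch, not part of the original flow): after standardization the flattened training images should have roughly zero mean and unit standard deviation.

In [ ]:
# Optional sanity check: inputs should now be roughly zero-mean / unit-std and
# flattened to PIC_L * PIC_L features.
print("train mean/std:", x_train.mean(), x_train.std())
print("shapes:", x_train.shape, x_test.shape, y_train.shape, y_test.shape)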

In [ ]:
outer_batch_size = 1
inner_batch_size = 8
loop_steps = 20
img_h = img_w = PIC_L  # images were already resized to PIC_L x PIC_L above

def resize(x, y):
  x = tf.image.resize(x, [img_h, img_w])
  x = tf.reshape(x, [-1, img_h * img_w])
  return x, y

resized_mnist_train = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train))  # .map(resize) is unnecessary here: inputs were resized above

with tf.device("/cpu:0"):
  dataset = (resized_mnist_train
             .batch(inner_batch_size)   # shots per inner step
             .batch(loop_steps)         # inner-loop steps per episode
             .batch(outer_batch_size)   # episodes per outer step
             .shuffle(1000).cache().repeat())
  ds_iter = iter(dataset)
  xval_ds = (resized_mnist_train
             .batch(inner_batch_size)
             .batch(outer_batch_size)
             .shuffle(1000).cache().repeat())
  xval_ds_iter = iter(xval_ds)

  test_ds = tf.data.Dataset.from_generator(
            lambda: MNIST_generator(x_test, y_test),
            output_types=(onp.float32, onp.float32))
  test_ds = test_ds.batch(inner_batch_size)
  test_ds_iter = iter(test_ds)
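
A quick shape check of the episode pipeline above (a minimal sketch that just draws one element from each iterator; this is safe because both datasets repeat).

In [ ]:
# Sanity check of the episode shapes produced by the pipeline above.
# Expected: x of shape (outer_batch_size, loop_steps, inner_batch_size, PIC_L * PIC_L)
# and y of shape (outer_batch_size, loop_steps, inner_batch_size, 10).
xb_demo, yb_demo = next(ds_iter)
print("train episode:", xb_demo.shape, yb_demo.shape)
xe_demo, ye_demo = next(xval_ds_iter)
print("held-out batch:", xe_demo.shape, ye_demo.shape)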

In [ ]:
message_size = 4
stateful = True
stateful_hidden_n = 7

import functools

l0 = PIC_L * PIC_L
def init_network(activation, sizes):
  network = MPNetwork(
      [MPDense(sizes[1], stateful=stateful, stateful_hidden_n=stateful_hidden_n),
       MPActivation(activation, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
       MPDense(10, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
       MPSoftmax(stateful=stateful, stateful_hidden_n=stateful_hidden_n)],
      MPCrossEntropyLoss(message_size, stateful=stateful, stateful_hidden_n=stateful_hidden_n))
  # Create the learned networks shared across layers:
  # one GRU block for weight updates (W_net) and one for bias updates (b_net).
  shared_params = {
      "W_net": GRUBlock(x_dim=2 + message_size,
                        carry_n=1+message_size+stateful_hidden_n),
      "b_net": GRUBlock(x_dim=2 + message_size,
                        carry_n=1+message_size+stateful_hidden_n),
#      "W_out_std": OutStandardizer(scale_init_val=0.05),
#      "b_out_std": OutStandardizer(scale_init_val=0.05),
#      "W_in_std": StandardizeInputsAndStates(["W_b", "W_in"]),
#      "b_in_std": StandardizeInputsAndStates(["b_b", "b_in"]),
                   }

  network.setup(in_dim=sizes[0], message_size=message_size, 
                inner_batch_size=inner_batch_size,
                shared_params=shared_params)
  return network

network = init_network(tf.sigmoid, (l0, 50))
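
As a quick shape check of the freshly built network, here is a minimal sketch (assuming the `network.init` / `network.forward` API used later in this notebook) that runs one forward pass on a training minibatch.

In [ ]:
# Quick shape check: one forward pass through the freshly initialized network.
demo_pfw = network.init(inner_batch_size)
demo_x = tf.constant(x_train[:inner_batch_size])
demo_out, _ = network.forward(demo_pfw, demo_x)
print("output shape:", demo_out.shape)  # expected: (inner_batch_size, 10)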

In [ ]:
def create_new_p_fw():
  return network.init(inner_batch_size)

POOL_SIZE = 16

pool = SamplePool(ps=tf.stack(
    [network.serialize_pfw(create_new_p_fw()) for _ in range(POOL_SIZE)]).numpy())

In [ ]:
loop_steps_t = tf.constant(loop_steps)

learning_schedule = 5e-4

training_regime = TrainingRegime(
    network, heldout_weight=1.0, hint_loss_ratio=0.7, remember_loss_ratio=None)

last_step = 0

# Minibatch init: initialize the network's statistics by observing several
# inner-loop steps rather than just one.
# This loop can be run multiple times to further improve the initialization.
for j in range(5):
  print("on", j)
  stats = []
  pfw = network.init(inner_batch_size)

  x_b, y_b = next(ds_iter)
  x_b, y_b = x_b[0], y_b[0]
  for i in range(loop_steps):
    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)
    stats.append(stats_i)
  # update
  network.update_statistics(stats, update_perc=0.5)

  print("final mean:")
  for p in tf.nest.flatten(pfw):
    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))


# The outer loop here uses Adam. SGD with momentum is more stable but much
# slower (see the optional cell below).
trainer = tf.keras.optimizers.Adam(learning_schedule)

loss_log = []
def smoothen(l, lookback=20):
  # Moving average; the first value is repeated as left padding so the output
  # has the same length as the input.
  kernel = [1./lookback] * lookback
  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")
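
As noted in the comment above, SGD with momentum is a more stable (but much slower) alternative for the outer loop. The cell below is an optional sketch of that swap; run it only if you want the alternative, and note that momentum=0.9 is an illustrative value, not taken from the original run.

In [ ]:
# Optional: swap the outer-loop optimizer for SGD with momentum.
# More stable than Adam but much slower; momentum=0.9 is an illustrative value.
trainer = tf.keras.optimizers.SGD(learning_schedule, momentum=0.9)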

In [ ]:
!mkdir -p tmp

In [ ]:
training_steps = 100000

# Note: the Adam trainer above was created with learning_schedule = 5e-4;
# reassigning the Python variable here does not change that optimizer's rate.
learning_schedule = 1e-4

@tf.function
def step(pfws, xts, yts, xes, yes, num_steps):
  print("compiling")
  with tf.GradientTape() as g:
    l, _, _ = training_regime.batch_mp_loss(pfws, xts, yts, xes, yes, num_steps)
  all_weights = network.get_trainable_weights()
  grads = g.gradient(l, all_weights)
  # Normalize each gradient tensor to (roughly) unit norm to avoid explosions.
  grads = [grad / (tf.norm(grad) + 1e-8) for grad in grads]
  trainer.apply_gradients(zip(grads, all_weights))
  return l


import time
start_time = time.time()

for i in range(last_step + 1, last_step + 1 + training_steps):
  last_step = i

  tmp_t = time.time()
  xts, yts = next(ds_iter)
  xes, yes = next(xval_ds_iter)

  batch = pool.sample(outer_batch_size)
  fwps = batch.ps

  l = step(fwps, xts, yts, xes, yes, loop_steps_t)
  loss_log.append(l)

  if i % 50 == 0:
    print(i)
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    plt.plot(smoothen(loss_log, 100), label='mp')
    plt.yscale('log')
    plt.legend()
    plt.show()
  if i % 2500 == 0:
    file_path = "tmp/weights"
    network.save_weights(file_path, last_step)
print("--- %s seconds ---" % (time.time() - start_time))

In [ ]:
def smoothen(l, lookback=20):
  kernel = [1./lookback] * lookback
  return onp.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")

plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
plt.ylim(1e-2, 3e-1)
plt.legend()
plt.show()

In [ ]:
# check saved checkpoints
!ls -lh tmp

In [ ]:
#@title Optionally, load weights from savedmodel
# override file_path as it will be useful later on.
file_path = "savedmodels/mnist_weights"

network = init_network(tf.sigmoid, (l0, 50))
network.load_weights(file_path)

In [ ]:
# Now let's run the learned optimizer for 20 inner steps and compare it with
# the SGD and Adam baselines.

eval_pfw = create_new_p_fw()

def prepare_for_analysis(pfw):
  return tf.concat([tf.reshape(t, [-1]) for t in pfw], 0).numpy()

print(prepare_for_analysis(eval_pfw).shape)

In [ ]:
test_sp_bs = 100

with tf.device("/cpu:0"):
  dataset_sp = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  dataset_sp = dataset_sp.batch(inner_batch_size).shuffle(1000).repeat()
  dataset_sp_iter = iter(dataset_sp)
  test_sp_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_sp_ds = test_sp_ds.batch(test_sp_bs).repeat()
  test_sp_ds_iter = iter(test_sp_ds)


def get_accuracy(pfw, xe, ye):
  targets = tf.argmax(ye, axis=-1)
  res, _ = network.forward(pfw, xe)
  predictions = tf.argmax(res, axis=-1)

  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
  accuracy = tot_correct / ye.shape[0]

  return accuracy
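
Before the full 100-repetition evaluation below, here is a minimal single-episode sketch of the evaluation protocol (assuming the iterators and `get_accuracy` defined above): run the learned optimizer for loop_steps inner updates, then report held-out accuracy.

In [ ]:
# Minimal single-episode sketch of the evaluation protocol used below.
pfw_demo = network.init(inner_batch_size)
for _ in range(loop_steps):
  xt_demo, yt_demo = next(dataset_sp_iter)
  pfw_demo, _ = network.inner_update(pfw_demo, xt_demo, yt_demo)
xe_demo, ye_demo = next(test_sp_ds_iter)
print("held-out accuracy after one episode:",
      get_accuracy(pfw_demo, xe_demo, ye_demo).numpy())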

In [ ]:
# MP step accuracy
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
  if s % 10 == 0:
    print("\nRepetition {}".format(s))
  MP_params_series = []

  mp_pfw = network.init(inner_batch_size)


  for i in range(loop_steps):
    # Train on a fresh minibatch from the training set; accuracy is evaluated
    # on a held-out test batch below.
    xt, yt = next(dataset_sp_iter)

    mp_pfw, _ = network.inner_update(mp_pfw, xt, yt)


    xe, ye = next(test_sp_ds_iter)
    accuracy = get_accuracy(mp_pfw, xe, ye)
    ev_losses[s, i] = accuracy

ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()

mp_baseline_m = ev_losses_m

In [ ]:
# Stateless baseline network (used for both the SGD and Adam comparisons).
adam_network = MPNetwork(
    [MPDense(50, stateful=False),
     MPActivation(tf.sigmoid, stateful=False),
     MPDense(10, stateful=False),
     MPSoftmax(stateful=False)],
    MPCrossEntropyLoss(message_size, stateful=False))
adam_network.setup(in_dim=l0, message_size=message_size)



def get_adam_accuracy(pfw, xe, ye):
  targets = tf.argmax(ye, axis=-1)
  res, _ = adam_network.forward(pfw, xe)
  predictions = tf.argmax(res, axis=-1)

  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
  accuracy = tot_correct / ye.shape[0]

  return accuracy

In [ ]:
# SGD baseline

all_SGD_series = []
eval_tot_steps = 100
sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
  if s % 10 == 0:
    print("\nRepetition {}".format(s))
  SGD_params_series = []

  sgd_pfw = [tf.Variable(t) for t in adam_network.init()]

  sgd_trainer = tf.keras.optimizers.SGD(0.1)

  @tf.function
  def step(xt, yt):
    with tf.GradientTape() as g:
      g.watch(sgd_pfw)
      y, _ = adam_network.forward(sgd_pfw, xt)
      l, _ = adam_network.compute_loss(y, yt)
      l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))

    grads = g.gradient(l, sgd_pfw)
    sgd_trainer.apply_gradients(zip(grads, sgd_pfw))


  for i in range(loop_steps):
    # Train on a fresh minibatch from the training set; accuracy is evaluated
    # on a held-out test batch below.
    xt, yt = next(dataset_sp_iter)
    step(xt, yt)

    xe, ye = next(test_sp_ds_iter)
    accuracy = get_adam_accuracy(sgd_pfw, xe, ye)
    sgd_ev_losses[s, i] = accuracy

ev_losses_m = np.mean(sgd_ev_losses, axis=0)
ev_losses_sd = np.std(sgd_ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
#plt.ylim(0.0, 0.04)
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()

sgd_m = ev_losses_m

In [ ]:
# Adam baseline (this cell reuses the sgd_* variable names from the SGD cell above).

all_SGD_series = []
eval_tot_steps = 100
sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
  if s % 10 == 0:
    print("\nRepetition {}".format(s))
  SGD_params_series = []

  sgd_pfw = [tf.Variable(t) for t in adam_network.init()]


  adam_trainer = tf.keras.optimizers.Adam(0.01)

  @tf.function
  def step(xt, yt):
    with tf.GradientTape() as g:
      g.watch(sgd_pfw)
      y, _ = adam_network.forward(sgd_pfw, xt)
      l, _ = adam_network.compute_loss(y, yt)
      l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))

    grads = g.gradient(l, sgd_pfw)
    adam_trainer.apply_gradients(zip(grads, sgd_pfw))


  for i in range(loop_steps):
    # Train on a fresh minibatch from the training set; accuracy is evaluated
    # on a held-out test batch below.
    xt, yt = next(dataset_sp_iter)
    step(xt, yt)

    xe, ye = next(test_sp_ds_iter)
    accuracy = get_adam_accuracy(sgd_pfw, xe, ye)
    sgd_ev_losses[s, i] = accuracy

ev_losses_m = np.mean(sgd_ev_losses, axis=0)
ev_losses_sd = np.std(sgd_ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()

adam_m = ev_losses_m

See what happens if you change the activation function at evaluation time to a hard, non-differentiable step function while reusing the learned optimizer.


In [ ]:
new_activation = lambda x: tf.maximum(0.0, tf.sign(x))  # hard (non-differentiable) step activation

l0 = PIC_L * PIC_L

new_network = init_network(new_activation, (l0, 50))
new_network.load_weights(file_path)

def get_accuracy(net, pfw, xe, ye):
  targets = tf.argmax(ye, axis=-1)
  res, _ = net.forward(pfw, xe)
  predictions = tf.argmax(res, axis=-1)

  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
  accuracy = tot_correct / ye.shape[0]

  return accuracy

In [ ]:
# MP step accuracy
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
  if s % 10 == 0:
    print("\nRepetition {}".format(s))
  MP_params_series = []

  mp_pfw = new_network.init(inner_batch_size)

  for i in range(loop_steps):
    # Train on a fresh minibatch from the training set; accuracy is evaluated
    # on a held-out test batch below.
    xt, yt = next(dataset_sp_iter)

    mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)

    xe, ye = next(test_sp_ds_iter)
    accuracy = get_accuracy(new_network, mp_pfw, xe, ye)
    ev_losses[s, i] = accuracy

ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
#plt.ylim(0.0, 0.04)
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()

mp_stepf_m = ev_losses_m

See what happens if you transfer the learned optimizer to a bigger network (100 hidden units) on full-resolution 28x28 MNIST.


In [ ]:
# big version of MNIST
(x_train_b, y_train_b), (x_test_b, y_test_b) = load_mnist()
y_train_b = one_hottify(y_train_b)
y_test_b = one_hottify(y_test_b)

# standardize inputs:
train_mean, train_std = x_train_b.mean(), x_train_b.std()

x_train_b = (x_train_b - train_mean) / train_std
x_test_b = (x_test_b - train_mean) / train_std

x_train_b = x_train_b.reshape([-1, 28 * 28])
x_test_b = x_test_b.reshape([-1, 28 * 28])

In [ ]:
new_network = init_network(tf.sigmoid, (28*28, 100))
new_network.load_weights(file_path)

In [ ]:
!ls tmp

In [ ]:
test_sp_bs = 100

with tf.device("/cpu:0"):
  dataset_sp_b = tf.data.Dataset.from_tensor_slices((x_train_b, y_train_b))
  dataset_sp_b = dataset_sp_b.batch(inner_batch_size).shuffle(1000).repeat()
  dataset_sp_b_iter = iter(dataset_sp_b)
  test_sp_b_ds = tf.data.Dataset.from_tensor_slices((x_test_b, y_test_b))
  test_sp_b_ds = test_sp_b_ds.batch(test_sp_bs).repeat()
  test_sp_b_ds_iter = iter(test_sp_b_ds)


def get_accuracy(pfw, xe, ye):
  targets = tf.argmax(ye, axis=-1)
  res, _ = new_network.forward(pfw, xe)
  predictions = tf.argmax(res, axis=-1)

  tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
  accuracy = tot_correct / ye.shape[0]

  return accuracy

In [ ]:
# MP step accuracy
# beware, this is very slow.
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
  if s % 10 == 0:
    print("\nRepetition {}".format(s))
  MP_params_series = []

  mp_pfw = new_network.init(inner_batch_size)

  for i in range(loop_steps):
    # Train on a fresh minibatch from the full-resolution training set;
    # accuracy is evaluated on a held-out test batch below.
    xt, yt = next(dataset_sp_b_iter)

    mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)

    xe, ye = next(test_sp_b_ds_iter)
    accuracy = get_accuracy(mp_pfw, xe, ye)
    ev_losses[s, i] = accuracy

ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()

mp_big_net_m = ev_losses_m

In [ ]:
# plot all together.

x_values = range(1, len(mp_baseline_m) + 1)

plt.plot(x_values, mp_baseline_m * 100, label='MPLP trained')
plt.plot(x_values, mp_stepf_m * 100, label='MPLP step function')
plt.plot(x_values, mp_big_net_m * 100, label='MPLP on bigger network')
plt.plot(x_values, sgd_m * 100, label='SGD')
plt.plot(x_values, adam_m * 100, label='Adam')
plt.xticks([4, 8, 12, 16, 20])

plt.xlabel("Number of steps")
plt.ylabel("Accuracy (%)")
plt.legend()


plt.savefig("tmp/mplp_evals.png")
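
If running interactively, the saved figure can be displayed again with the IPython.display helpers imported at the top; a minimal sketch:

In [ ]:
# Re-display the saved comparison plot (uses the IPython.display import above).
display.display(display.Image(filename="tmp/mplp_evals.png"))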