Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [ ]:
import os
import io
import glob
import itertools
import random as pyrandom  # used by MNIST_generator below
from collections import defaultdict

import numpy as np
import numpy as onp
import matplotlib.pyplot as plt  # visualization
import IPython.display as display
from PIL import Image

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
tf.get_logger().setLevel('ERROR')
import tensorflow_datasets as tfds
In [ ]:
!pip install --upgrade -e "git+https://github.com/google-research/self-organising-systems.git#egg=mplp&subdirectory=mplp"
In [ ]:
# Symlink the saved models from the editable pip checkout (cloned under src/ above).
!ln -s src/mplp/mplp/savedmodels savedmodels
In [ ]:
from mplp.tf_layers import MPDense
from mplp.tf_layers import MPActivation
from mplp.tf_layers import MPSoftmax
from mplp.tf_layers import MPCrossEntropyLoss
from mplp.tf_layers import MPNetwork
from mplp.util import SamplePool
from mplp.training import TrainingRegime
from mplp.core import GRUBlock
from mplp.core import OutStandardizer
from mplp.core import StandardizeInputsAndStates
In [ ]:
def load_mnist():
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = np.array(x_train).astype(np.float32)
y_train = np.array(y_train).astype(np.int64)
x_test = np.array(x_test).astype(np.float32)
y_test = np.array(y_test).astype(np.int64)
return (x_train, y_train),(x_test, y_test)
(x_train, y_train),(x_test, y_test) = load_mnist()
def one_hottify(dset):
one_hottified = onp.zeros([dset.shape[0], 10], dtype=onp.float32)
one_hottified[onp.arange(dset.size), dset] = 1.0
return one_hottified
def MNIST_generator(inputs, labels):
while True:
idx = pyrandom.randrange(len(inputs))
x = inputs[idx]
y = labels[idx]
yield x, y
PIC_L = 12
x_train = tf.image.resize(onp.expand_dims(x_train, -1), size=(PIC_L, PIC_L)).numpy()
x_test = tf.image.resize(onp.expand_dims(x_test, -1), size=(PIC_L, PIC_L)).numpy()
y_train = one_hottify(y_train)
y_test = one_hottify(y_test)
# standardize inputs:
train_mean, train_std = x_train.mean(), x_train.std()
x_train = (x_train - train_mean) / train_std
x_test = (x_test - train_mean) / train_std
x_train = x_train.reshape([-1, PIC_L * PIC_L])
x_test = x_test.reshape([-1, PIC_L * PIC_L])
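# Optional sanity check (not part of the original flow): after resizing to
# PIC_L x PIC_L, one-hot encoding, standardization, and flattening, the arrays
# are expected to look like this.
print(x_train.shape, y_train.shape)   # expected: (60000, 144) (60000, 10)
print(x_test.shape, y_test.shape)     # expected: (10000, 144) (10000, 10)
print(x_train.mean(), x_train.std())  # roughly 0 and 1 after standardization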
In [ ]:
outer_batch_size = 1
inner_batch_size = 8
loop_steps = 20
img_h = img_w = 12
def resize(x, y):
x = tf.image.resize(x, [img_h, img_w])
x = tf.reshape(x, [-1, img_h * img_w])
return x, y
resized_mnist_train = tf.data.Dataset.from_tensor_slices(
(x_train, y_train))#.map(resize)
with tf.device("/cpu:0"):
dataset = resized_mnist_train.batch(inner_batch_size).batch(loop_steps).batch(outer_batch_size).shuffle(1000).cache().repeat()
ds_iter = iter(dataset)
xval_ds = resized_mnist_train.batch(inner_batch_size).batch(outer_batch_size).shuffle(1000).cache().repeat()
xval_ds_iter = iter(xval_ds)
test_ds = tf.data.Dataset.from_generator(
lambda: MNIST_generator(x_test, y_test),
output_types=(onp.float32, onp.float32))
test_ds = test_ds.batch(inner_batch_size)
test_ds_iter = iter(test_ds)
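# Optional shape check: the nested batching above should yield meta-training
# tensors shaped [outer_batch_size, loop_steps, inner_batch_size, PIC_L*PIC_L]
# for the inputs and [..., 10] for the one-hot labels (variable names here are
# ad hoc).
xb_check, yb_check = next(ds_iter)
print(xb_check.shape, yb_check.shape)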
In [ ]:
message_size = 4
stateful = True
stateful_hidden_n = 7
import functools
l0 = PIC_L * PIC_L
def init_network(activation, sizes):
network = MPNetwork([MPDense(sizes[1], stateful=stateful, stateful_hidden_n=stateful_hidden_n),
MPActivation(activation, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
MPDense(10, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
MPSoftmax(stateful=stateful, stateful_hidden_n=stateful_hidden_n),
],
MPCrossEntropyLoss(message_size, stateful=stateful, stateful_hidden_n=stateful_hidden_n))
# create shared networks!
shared_params = {
"W_net": GRUBlock(x_dim=2 + message_size,
carry_n=1+message_size+stateful_hidden_n),
"b_net": GRUBlock(x_dim=2 + message_size,
carry_n=1+message_size+stateful_hidden_n),
# "W_out_std": OutStandardizer(scale_init_val=0.05),
# "b_out_std": OutStandardizer(scale_init_val=0.05),
# "W_in_std": StandardizeInputsAndStates(["W_b", "W_in"]),
# "b_in_std": StandardizeInputsAndStates(["b_b", "b_in"]),
}
network.setup(in_dim=sizes[0], message_size=message_size,
inner_batch_size=inner_batch_size,
shared_params=shared_params)
return network
network = init_network(tf.sigmoid, (l0, 50))
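# Optional: inspect the meta-learner that the outer loop will optimize.
# get_trainable_weights() returns the shared message-passing (GRU) parameters.
meta_weights = network.get_trainable_weights()
print("meta-parameter tensors:", len(meta_weights))
print("total meta-parameters:", sum(int(tf.size(w)) for w in meta_weights))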
In [ ]:
def create_new_p_fw():
return network.init(inner_batch_size)
POOL_SIZE = 16
pool = SamplePool(ps=tf.stack(
[network.serialize_pfw(create_new_p_fw()) for _ in range(POOL_SIZE)]).numpy())
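# The pool stores serialized fast-weight states; sampling returns a view whose
# .ps tensor is what the meta-training step below consumes (optional check).
pool_check = pool.sample(outer_batch_size)
print("sampled pool batch shape:", pool_check.ps.shape)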
In [ ]:
loop_steps_t = tf.constant(loop_steps)
learning_schedule = 5e-4
training_regime = TrainingRegime(
network, heldout_weight=1.0, hint_loss_ratio=0.7, remember_loss_ratio=None)
last_step = 0
# Minibatch initialization: gather statistics over several steps rather than
# a single one. This loop can also be re-run multiple times to further refine
# the initialization.
for j in range(5):
print("on", j)
stats = []
pfw = network.init(inner_batch_size)
x_b, y_b = next(ds_iter)
x_b, y_b = x_b[0], y_b[0]
for i in range(loop_steps):
pfw, stats_i = network.minibatch_init(x_b[i], y_b[i], x_b[i].shape[0], pfw=pfw)
stats.append(stats_i)
# update
network.update_statistics(stats, update_perc=0.5)
print("final mean:")
for p in tf.nest.flatten(pfw):
print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))
# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.
trainer = tf.keras.optimizers.Adam(learning_schedule)
loss_log = []
def smoothen(l, lookback=20):
  # Moving average; the input is left-padded with its first value so the
  # output keeps the same length as the input.
  kernel = [1./lookback] * lookback
  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")
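# Tiny example of the smoothing helper: the left padding keeps the output the
# same length as the input.
print(smoothen([1., 1., 2., 2.], lookback=2))  # expected: [1.  1.  1.5 2. ]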
In [ ]:
!mkdir -p tmp  # -p: no error if tmp already exists
In [ ]:
training_steps = 100000
# Note: `trainer` was already created above with learning_schedule = 5e-4;
# re-create the optimizer if you want this lower rate to take effect.
learning_schedule = 1e-4
@tf.function
def step(pfws, xts, yts, xes, yes, num_steps):
  print("compiling")  # printed only when tf.function retraces
  with tf.GradientTape() as g:
    l, _, _ = training_regime.batch_mp_loss(pfws, xts, yts, xes, yes, num_steps)
  all_weights = network.get_trainable_weights()
  grads = g.gradient(l, all_weights)
  # Normalize each gradient to unit norm to keep updates bounded
  # (a simple alternative to gradient clipping).
  grads = [gr / (tf.norm(gr) + 1e-8) for gr in grads]
  trainer.apply_gradients(zip(grads, all_weights))
  return l
import time
start_time = time.time()
for i in range(last_step + 1, last_step +1 + training_steps):
last_step = i
tmp_t = time.time()
xts, yts = next(ds_iter)
xes, yes = next(xval_ds_iter)
batch = pool.sample(outer_batch_size)
fwps = batch.ps
l = step(fwps, xts, yts, xes, yes, loop_steps_t)
loss_log.append(l)
if i % 50 == 0:
print(i)
print("--- %s seconds ---" % (time.time() - start_time))
if i % 500 == 0:
plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
plt.legend()
plt.show()
if i % 2500 == 0:
file_path = "tmp/weights"
network.save_weights(file_path, last_step)
print("--- %s seconds ---" % (time.time() - start_time))
In [ ]:
!ls tmp -lh
In [ ]:
def smoothen(l, lookback=20):
kernel = [1./lookback] * lookback
return onp.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")
plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
plt.ylim(1e-2, 3e-1)
plt.legend()
plt.show()
In [ ]:
# check saved checkpoints
! ls tmp -lh
In [ ]:
#@title Optionally, load weights from savedmodel
# Override file_path here; it is reused by the evaluation cells below.
file_path = "savedmodels/mnist_weights"
network = init_network(tf.sigmoid, (l0, 50))
network.load_weights(file_path)
In [ ]:
# Now let's run the trained MPLP for loop_steps inner-training steps and compare it with SGD and Adam.
eval_pfw = create_new_p_fw()
def prepare_for_analysis(pfw):
return tf.concat([tf.reshape(t, [-1]) for t in pfw], 0).numpy()
print(prepare_for_analysis(eval_pfw).shape)
In [ ]:
test_sp_bs = 100
with tf.device("/cpu:0"):
dataset_sp = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_sp = dataset_sp.batch(inner_batch_size).shuffle(1000).repeat()
dataset_sp_iter = iter(dataset_sp)
test_sp_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_sp_ds = test_sp_ds.batch(test_sp_bs).repeat()
test_sp_ds_iter = iter(test_sp_ds)
def get_accuracy(pfw, xe, ye):
targets = tf.argmax(ye, axis=-1)
res, _ = network.forward(pfw, xe)
predictions = tf.argmax(res, axis=-1)
tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
accuracy = tot_correct / ye.shape[0]
return accuracy
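# Optional baseline (not part of the original flow): a freshly initialized
# fast-weight state should sit near chance accuracy (~0.1) before any
# inner-loop updates.
fresh_pfw = network.init(inner_batch_size)
xe_check, ye_check = next(test_sp_ds_iter)
print("untrained accuracy:", float(get_accuracy(fresh_pfw, xe_check, ye_check)))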
In [ ]:
# MP step accuracy
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
if s % 10 == 0:
print("\nRepetition {}".format(s))
MP_params_series = []
mp_pfw = network.init(inner_batch_size)
for i in range(loop_steps):
    # Train on a fresh training minibatch, then measure accuracy on held-out
    # test data after each inner step.
xt, yt = next(dataset_sp_iter)
mp_pfw, _= network.inner_update(mp_pfw, xt, yt)
xe, ye = next(test_sp_ds_iter)
accuracy = get_accuracy(mp_pfw, xe, ye)
ev_losses[s, i] = accuracy
ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)
ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()
mp_baseline_m = ev_losses_m
In [ ]:
# Baseline (non-message-passing) network; it is trained below with plain SGD and with Adam.
adam_network = MPNetwork([MPDense(50, stateful=False),
MPActivation(tf.sigmoid, stateful=False),
MPDense(10, stateful=False),
MPSoftmax(stateful=False),
],
MPCrossEntropyLoss(message_size, stateful=False))
adam_network.setup(in_dim=l0, message_size=message_size)
def get_adam_accuracy(pfw, xe, ye):
targets = tf.argmax(ye, axis=-1)
res, _ = adam_network.forward(pfw, xe)
predictions = tf.argmax(res, axis=-1)
tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
accuracy = tot_correct / ye.shape[0]
return accuracy
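# Optional: count the fast weights of this baseline network, for a rough sense
# of the model size trained by SGD / Adam below (variable names are ad hoc).
baseline_weights = adam_network.init()
print("parameter tensors:", len(baseline_weights))
print("total parameters:", sum(int(tf.size(w)) for w in baseline_weights))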
In [ ]:
# SGD baseline. Note: the optimizer variable below is still named adam_trainer, for symmetry with the Adam cell that follows.
all_SGD_series = []
eval_tot_steps = 100
sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
if s % 10 == 0:
print("\nRepetition {}".format(s))
SGD_params_series = []
sgd_pfw = [tf.Variable(t) for t in adam_network.init()]
adam_trainer = tf.keras.optimizers.SGD(0.1)
@tf.function
def step(xt, yt):
with tf.GradientTape() as g:
g.watch(sgd_pfw)
y, _ = adam_network.forward(sgd_pfw, xt)
l, _ = adam_network.compute_loss(y, yt)
l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))
grads = g.gradient(l, sgd_pfw)
adam_trainer.apply_gradients(zip(grads, sgd_pfw))
for i in range(loop_steps):
    # Train on a fresh training minibatch, then measure accuracy on held-out
    # test data after each inner step.
xt, yt = next(dataset_sp_iter)
step(xt, yt)
if i % 1 == 0:
xe, ye = next(test_sp_ds_iter)
accuracy = get_adam_accuracy(sgd_pfw, xe, ye)
sgd_ev_losses[s, i] = accuracy
ev_losses_m = np.mean(sgd_ev_losses, axis=0)
ev_losses_sd = np.std(sgd_ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)
ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()
sgd_m = ev_losses_m
In [ ]:
# Adam baseline (reuses the sgd_-prefixed variable names from the previous cell).
all_SGD_series = []
eval_tot_steps = 100
sgd_ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
if s % 10 == 0:
print("\nRepetition {}".format(s))
SGD_params_series = []
sgd_pfw = [tf.Variable(t) for t in adam_network.init()]
adam_trainer = tf.keras.optimizers.Adam(0.01)
@tf.function
def step(xt, yt):
with tf.GradientTape() as g:
g.watch(sgd_pfw)
y, _ = adam_network.forward(sgd_pfw, xt)
l, _ = adam_network.compute_loss(y, yt)
l = tf.reduce_mean(tf.reduce_sum(l, axis=[1]))
grads = g.gradient(l, sgd_pfw)
adam_trainer.apply_gradients(zip(grads, sgd_pfw))
for i in range(loop_steps):
    # Train on a fresh training minibatch, then measure accuracy on held-out
    # test data after each inner step.
xt, yt = next(dataset_sp_iter)
step(xt, yt)
if i % 1 == 0:
xe, ye = next(test_sp_ds_iter)
accuracy = get_adam_accuracy(sgd_pfw, xe, ye)
sgd_ev_losses[s, i] = accuracy
ev_losses_m = np.mean(sgd_ev_losses, axis=0)
ev_losses_sd = np.std(sgd_ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)
ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()
adam_m = ev_losses_m
In [ ]:
# Step-function activation: 0 for x <= 0, 1 for x > 0.
new_activation = lambda x: tf.maximum(0.0, tf.sign(x))
l0 = PIC_L * PIC_L
new_network = init_network(new_activation, (l0, 50))
new_network.load_weights(file_path)
def get_accuracy(net, pfw, xe, ye):
targets = tf.argmax(ye, axis=-1)
res, _ = net.forward(pfw, xe)
predictions = tf.argmax(res, axis=-1)
tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
accuracy = tot_correct / ye.shape[0]
return accuracy
In [ ]:
# MP step accuracy
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
if s % 10 == 0:
print("\nRepetition {}".format(s))
MP_params_series = []
mp_pfw = new_network.init(inner_batch_size)
for i in range(loop_steps):
    # Train on a fresh training minibatch, then measure accuracy on held-out
    # test data after each inner step.
xt, yt = next(dataset_sp_iter)
mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)
xe, ye = next(test_sp_ds_iter)
accuracy = get_accuracy(new_network, mp_pfw, xe, ye)
ev_losses[s, i] = accuracy
ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)
ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()
mp_stepf_m = ev_losses_m
In [ ]:
# Full-resolution (28x28) version of MNIST.
(x_train_b, y_train_b), (x_test_b, y_test_b) = load_mnist()
y_train_b = one_hottify(y_train_b)
y_test_b = one_hottify(y_test_b)
# standardize inputs:
train_mean, train_std = x_train_b.mean(), x_train_b.std()
x_train_b = (x_train_b - train_mean) / train_std
x_test_b = (x_test_b - train_mean) / train_std
x_train_b = x_train_b.reshape([-1, 28 * 28])
x_test_b = x_test_b.reshape([-1, 28 * 28])
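# Sanity check: the full-resolution arrays should now be flattened to
# 28 * 28 = 784 features per example.
print(x_train_b.shape, x_test_b.shape)  # expected: (60000, 784) (10000, 784)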
In [ ]:
new_network = init_network(tf.sigmoid, (28*28, 100))
new_network.load_weights(file_path)
In [ ]:
!ls tmp
In [ ]:
test_sp_bs = 100
with tf.device("/cpu:0"):
dataset_sp_b = tf.data.Dataset.from_tensor_slices((x_train_b, y_train_b))
dataset_sp_b = dataset_sp_b.batch(inner_batch_size).shuffle(1000).repeat()
dataset_sp_b_iter = iter(dataset_sp_b)
test_sp_b_ds = tf.data.Dataset.from_tensor_slices((x_test_b, y_test_b))
test_sp_b_ds = test_sp_b_ds.batch(test_sp_bs).repeat()
test_sp_b_ds_iter = iter(test_sp_b_ds)
def get_accuracy(pfw, xe, ye):
targets = tf.argmax(ye, axis=-1)
res, _ = new_network.forward(pfw, xe)
predictions = tf.argmax(res, axis=-1)
tot_correct = tf.reduce_sum(tf.cast(tf.equal(predictions, targets), tf.float32))
accuracy = tot_correct / ye.shape[0]
return accuracy
In [ ]:
# MP step accuracy
# beware, this is very slow.
all_MP_series = []
eval_tot_steps = 100
ev_losses = np.zeros([eval_tot_steps, loop_steps])
for s in range(eval_tot_steps):
if s % 10 == 0:
print("\nRepetition {}".format(s))
MP_params_series = []
mp_pfw = new_network.init(inner_batch_size)
for i in range(loop_steps):
    # Train on a fresh training minibatch, then measure accuracy on held-out
    # test data after each inner step.
xt, yt = next(dataset_sp_b_iter)
mp_pfw, _ = new_network.inner_update(mp_pfw, xt, yt)
xe, ye = next(test_sp_b_ds_iter)
accuracy = get_accuracy(mp_pfw, xe, ye)
ev_losses[s, i] = accuracy
ev_losses_m = np.mean(ev_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("mean:", ev_losses_m)
print("sd:", ev_losses_sd)
ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(1, len(ev_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(ev_losses_m) + 1), ev_losses_m, label='eval accuracy')
plt.xlabel("num steps")
plt.ylabel("Accuracy")
plt.legend()
mp_big_net_m = ev_losses_m
In [ ]:
# Plot all evaluation curves together.
x_values = range(1, len(mp_baseline_m) + 1)
plt.plot(x_values, mp_baseline_m * 100, label='MPLP trained')
plt.plot(x_values, mp_stepf_m * 100, label='MPLP step function')
plt.plot(x_values, mp_big_net_m * 100, label='MPLP on bigger network')
plt.plot(x_values, sgd_m * 100, label='SGD')
plt.plot(x_values, adam_m * 100, label='Adam')
plt.xticks([4, 8, 12, 16, 20])
plt.xlabel("Number of steps")
plt.ylabel("Accuracy (%)")
plt.legend()
with open("tmp/mplp_evals.png", "wb") as fout:
plt.savefig(fout)
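# Optionally, also dump the raw evaluation curves (filename is arbitrary) so
# the comparison can be re-plotted without rerunning the slow evaluation loops.
np.savez("tmp/mplp_eval_curves.npz",
         mp_baseline=mp_baseline_m, mp_stepf=mp_stepf_m,
         mp_big_net=mp_big_net_m, sgd=sgd_m, adam=adam_m)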