In [ ]:
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [10,6]
#from copy import deepcopy
In [ ]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()
print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))
In [ ]:
seed = 0
plot = True
## Inner Optimisations
inner_stepsize = 0.02 # stepsize in inner SGD
inner_epochs = 1 # number of epochs of each inner SGD
inner_batchsize = 10 # size of training minibatches
## Outer Optimisations
# initial stepsize of outer optimization, i.e., meta-optimization
outer_stepsize0 = 0.1
# number of outer updates; each iteration we sample one task and update on it
outer_steps = 10000
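# Illustrative only: the outer (meta) stepsize is annealed linearly from outer_stepsize0
# towards zero over training, the same schedule as
# outer_stepsize0*(1-outer_step/outer_steps) used in the training loop below.
for s in (0, outer_steps//2, outer_steps-1):
    print(s, outer_stepsize0 * (1 - s / outer_steps))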
In [ ]:
#torch.manual_seed(seed)
tf.set_random_seed(seed+0)
rng = np.random.RandomState(seed)
rng
In [ ]:
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
In [ ]:
# Generate regression problems (random sine tasks) that we're going to learn about
def gen_task():
phase = rng.uniform(low=0, high=2*np.pi)
ampl = rng.uniform(2., 5.)
f_randomsine = lambda x : np.sin(x+phase)*ampl
return f_randomsine # i.e. return a random *function* to learn
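
# Quick sanity check (illustrative, not in the original): every call to gen_task()
# returns a different random sine function; evaluating it on x_all gives that task's targets.
f_example = gen_task()
print(f_example(x_all[:3]).ravel())  # three target values from one sampled task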
In [ ]:
# Define model - this model is going to be easily *trainable* to
# solve each of the problems we generate
#model = nn.Sequential(
# nn.Linear(1, 64),
# nn.Tanh(),
# nn.Linear(64, 64),
# nn.Tanh(),
# nn.Linear(64, 1),
#)
#inits=dict()
inits=dict(kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')
model = tf.keras.Sequential([
# Reptile paper uses ReLU, but Tanh gives slightly better results
tf.keras.layers.Dense(64, activation="tanh", input_shape=(1,), **inits), # input shape required
tf.keras.layers.Dense(64, activation="tanh", **inits),
tf.keras.layers.Dense(1, **inits) # NB: Keras 'knows' the previous layer sizes
])
# Alternatively (more PyTorch-y):
class SineModel(tf.keras.Model):
def __init__(self):
super(SineModel, self).__init__()
# Input shape is not required, since the weights are created the first time the model is called
# (there is a problem with shapes if we try to define the inits here, though)
self.dense1 = tf.keras.layers.Dense(units=64, activation="tanh")
self.dense2 = tf.keras.layers.Dense(units=64, activation="tanh")
self.dense3 = tf.keras.layers.Dense(units=1)
def call(self, input):
"""Run the model."""
x = self.dense1(input)
x = self.dense2(x)
x = self.dense3(x)
return x
#model = SineModel()
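# A minimal sketch (assuming the SineModel variant above is used instead of the
# Sequential model): a subclassed Keras model only creates its weights on the first
# call, so model.variables is empty until the model has seen an input. One dummy
# forward pass builds the weights so the Reptile loop below can snapshot them.
#model = SineModel()
#_ = model(tf.zeros([1, 1]))   # builds each Dense layer's kernel and bias
#print(len(model.variables))   # 6 variables: 3 layers x (kernel + bias)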
#def totorch(x):
# return torch.autograd.Variable(torch.Tensor(x))
#def train_current_model_on_batch(x, y):
# x = totorch(x)
# y = totorch(y)
# model.zero_grad()
# ypred = model(x)
# loss = (ypred - y).pow(2).mean()
# loss.backward()
# for param in model.parameters():
# param.data -= inner_stepsize * param.grad.data
def train_current_model_on_batch(x, y):
x = tfe.Variable(x, dtype='float32')
y = tfe.Variable(y, dtype='float32')
with tfe.GradientTape() as tape:
ypred = model(x)
loss = tf.reduce_mean(tf.square(ypred - y))
grads = tape.gradient(loss, model.variables)
for param, grad in zip(model.variables, grads):
param.assign_sub(inner_stepsize * grad)
#def predict_using_current_model(x):
# x = totorch(x)
# return model(x).data.numpy()
def predict_using_current_model(x):
x = tfe.Variable(x, dtype='float32')
return model(x).numpy()
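
# Optional smoke test (illustrative): one inner-SGD step on a random task should
# typically reduce the squared error on that same batch. Uncommenting this takes a
# real gradient step, so it slightly perturbs the model before meta-training starts.
#f_check = gen_task()
#x_check = x_all[rng.choice(len(x_all), size=inner_batchsize)]
#print(np.square(predict_using_current_model(x_check) - f_check(x_check)).mean())
#train_current_model_on_batch(x_check, f_check(x_check))
#print(np.square(predict_using_current_model(x_check) - f_check(x_check)).mean())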
In [ ]:
model.summary()
In [ ]:
# Choose a fixed task and minibatch for visualization
f_plot = gen_task() # This is one specific task
xtrain_plot = x_all[rng.choice(len(x_all), size=inner_batchsize)] # One illustrative batch
In [ ]:
# Reptile training loop
for outer_step in range(outer_steps):
#weights_before = deepcopy(model.state_dict())
weights_before = [ param.read_value() for param in model.variables ]
# Generate a task of the sort we want to learn about learning
f = gen_task() # This is a different task every time
    y_all = f(x_all) # Get the correct outputs for the x_all values (i.e. a dataset)
# Do SGD on this task + dataset
inds = rng.permutation(len(x_all)) # Shuffle data indices
for _ in range(inner_epochs):
for start in range(0, len(x_all), inner_batchsize):
batch_indices = inds[start:start+inner_batchsize]
train_current_model_on_batch(
x_all[batch_indices], y_all[batch_indices]
)
# Interpolate between current weights and trained weights from this task
# i.e. (weights_before - weights_after) is the meta-gradient
#weights_after = model.state_dict()
weights_after = [ param.read_value() for param in model.variables ]
outer_stepsize = outer_stepsize0*(1-outer_step/outer_steps) # linear schedule
# This updates the weights in the model -
# Not a 'training' gradient update, but one reflecting
# the training that occurred during the inner loop training
#model.load_state_dict({name :
# weights_before[name] + (weights_after[name]-weights_before[name])*outer_stepsize
# for name in weights_before})
for param, w_before, w_after in zip(model.variables, weights_before, weights_after):
param.assign( w_before + (w_after-w_before)*outer_stepsize )
# Periodically plot the results on a particular task and minibatch
    if plot and (outer_step == 0 or (outer_step+1) % 1000 == 0):
fig = plt.figure(figsize=(10,6))
ax = plt.subplot(111)
f = f_plot
#weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
weights_saved = [ param.read_value() for param in model.variables ]
# Plot the initial model (having seen no data to train on for this task)
plt.plot(x_all, predict_using_current_model(x_all),
label="pred after 0", color=(0.5,0.5,1))
for inner_step in range(1, 32+1):
train_current_model_on_batch(xtrain_plot, f(xtrain_plot))
#if (inner_step) % 8 == 0:
if inner_step in [1, 2, 4, 8, 16, 32]:
frac = np.log(inner_step)/np.log(32.)
plt.plot(x_all, predict_using_current_model(x_all),
label="pred after %i" % (inner_step),
color=(frac, 0, 1-frac))
lossval = np.square(predict_using_current_model(x_all) - f(x_all)).mean()
plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
plt.ylim(-4.5,4.5)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
print(f"outer_step {outer_step+1:6d}")
# would be better to average loss over a set of examples,
# but this is optimized for brevity
print(f"loss on plotted curve {lossval:.3f}")
print(f"-----------------------------")
#model.load_state_dict(weights_before) # restore from snapshot
for param, w_saved in zip(model.variables, weights_saved):
param.assign( w_saved )
print("FINISHED")
In [ ]:
#weights_before
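In [ ]:
# Illustrative few-shot check (not in the original notebook): adapt the meta-learned
# initialisation to a brand-new task with a few inner steps, report the loss on the
# full curve, then restore the meta-learned weights (mirroring the plotting code above).
f_new = gen_task()
x_new = x_all[rng.choice(len(x_all), size=inner_batchsize)]
weights_meta = [ param.read_value() for param in model.variables ]
for _ in range(32):
    train_current_model_on_batch(x_new, f_new(x_new))
print(np.square(predict_using_current_model(x_all) - f_new(x_all)).mean())
for param, w in zip(model.variables, weights_meta):
    param.assign(w)  # restore the meta-learned initialisation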
In [ ]:
# Scratch check of eager Variable semantics: read_value() and numpy() both return
# copies of the current value, so `b` and `d` below keep the pre-update value
# even after assign_add changes `a`.
a = tfe.Variable([[1, 2], [3, 4]])
b = a.read_value()  # snapshot as a Tensor (what the Reptile loop stores)
d = a.numpy()       # snapshot as a numpy array
a.assign_add([[2, 2], [1, 1]])
In [ ]:
b
In [ ]:
d
In [ ]: