In [ ]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [10,6]
from copy import deepcopy
In [ ]:
seed = 0
plot = True
## Inner Optimisations
inner_stepsize = 0.02 # stepsize in inner SGD
inner_epochs = 1 # number of epochs of each inner SGD
inner_batchsize = 10 # Size of training minibatches
## Outer Optimisations
# initial stepsize of outer optimization, i.e., meta-optimization
outer_stepsize0 = 0.1
# number of outer updates; each iteration we sample one task and update on it
outer_steps = 10000
In [ ]:
torch.manual_seed(seed)                 # seed PyTorch (weight initialisation)
rng = np.random.RandomState(seed)       # seed NumPy (task sampling, minibatch selection)
rng
In [ ]:
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points, as a (50, 1) column vector
In [ ]:
# Generate regression problems that we're going to learn about
def gen_task():
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(2., 5.)
    f_randomsine = lambda x: np.sin(x + phase) * ampl
    return f_randomsine  # i.e. return a random *function* to learn
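In [ ]:
# Optional sanity check: sample a handful of tasks and plot them, just to see
# the range of sine waves (random phase and amplitude) the model must adapt to.
if plot:
    for _ in range(5):
        f_example = gen_task()
        plt.plot(x_all, f_example(x_all))
    plt.title("A few sampled tasks")
    plt.show()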
In [ ]:
# Define model - this model is going to be easily *trainable* to
# solve each of the problems we generate
model = nn.Sequential(
    nn.Linear(1, 64),
    nn.Tanh(),  # Reptile paper uses ReLU, but Tanh gives slightly better results
    nn.Linear(64, 64),
    nn.Tanh(),
    nn.Linear(64, 1),
)

def train_current_model_on_batch(x, y):
    """One step of inner-loop SGD on a single minibatch (x, y)."""
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)
    model.zero_grad()
    ypred = model(x)
    loss = (ypred - y).pow(2).mean()  # mean squared error
    loss.backward()
    for param in model.parameters():  # manual SGD update
        param.data -= inner_stepsize * param.grad.data

def predict_using_current_model(x):
    x = torch.tensor(x, dtype=torch.float32)
    return model(x).data.numpy()
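In [ ]:
# Optional quick check of the helpers above: a few inner-SGD steps on one sampled
# task should already lower the MSE. The model weights are saved and restored so
# that meta-training below still starts from the untouched initialisation.
weights_backup = deepcopy(model.state_dict())
f_check = gen_task()
y_check = f_check(x_all)
print("MSE before any training:", np.square(predict_using_current_model(x_all) - y_check).mean())
for _ in range(5):
    train_current_model_on_batch(x_all, y_check)
print("MSE after 5 inner steps :", np.square(predict_using_current_model(x_all) - y_check).mean())
model.load_state_dict(weights_backup)  # restore the original initialisation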
In [ ]:
# Choose a fixed task and minibatch for visualization
f_plot = gen_task() # This is one specific task
xtrain_plot = x_all[rng.choice(len(x_all), size=inner_batchsize)]
In [ ]:
# Reptile training loop
for outer_step in range(outer_steps):
    weights_before = deepcopy(model.state_dict())
    # Generate a task of the sort we want to learn about learning
    f = gen_task()  # This is a different task every time
    y_all = f(x_all)  # Get the correct outputs for the x_all values (i.e. a dataset)
    # Do SGD on this task + dataset
    inds = rng.permutation(len(x_all))  # Shuffle data indices
    for _ in range(inner_epochs):
        for start in range(0, len(x_all), inner_batchsize):
            batch_indices = inds[start:start+inner_batchsize]
            train_current_model_on_batch(
                x_all[batch_indices], y_all[batch_indices]
            )
    # Interpolate between current weights and trained weights from this task
    # i.e. (weights_before - weights_after) is the meta-gradient
    weights_after = model.state_dict()
    outer_stepsize = outer_stepsize0 * (1 - outer_step/outer_steps)  # linear schedule
    # This updates the weights in the model -
    # not a 'training' gradient update, but one reflecting
    # the training that occurred during the inner loop
    model.load_state_dict({name:
        weights_before[name] + (weights_after[name] - weights_before[name]) * outer_stepsize
        for name in weights_before})
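    # The update above is the Reptile outer step (Nichol et al., 2018):
    #     theta <- theta + outer_stepsize * (phi - theta)
    # where theta = weights_before (the shared initialisation) and
    # phi = weights_after (the weights adapted to the sampled task).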
    # Periodically plot the results on a particular task and minibatch
    if plot and (outer_step == 0 or (outer_step+1) % 1000 == 0):
        fig = plt.figure(figsize=(10,6))
        ax = plt.subplot(111)
        f = f_plot
        weights_before = deepcopy(model.state_dict())  # save snapshot before evaluation
        # Plot the initial model (having seen no data to train on for this task)
        plt.plot(x_all, predict_using_current_model(x_all),
                 label="pred after 0", color=(0.5, 0.5, 1))
        for inner_step in range(1, 32+1):
            train_current_model_on_batch(xtrain_plot, f(xtrain_plot))
            #if (inner_step) % 8 == 0:
            if inner_step in [1, 2, 4, 8, 16, 32]:
                frac = np.log(inner_step)/np.log(32.)
                plt.plot(x_all, predict_using_current_model(x_all),
                         label="pred after %i" % (inner_step),
                         color=(frac, 0, 1-frac))
        lossval = np.square(predict_using_current_model(x_all) - f(x_all)).mean()
        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4.5, 4.5)
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.show()
        print(f"outer_step {outer_step+1:6d}")
        # would be better to average loss over a set of examples,
        # but this is optimized for brevity
        print(f"loss on plotted curve {lossval:.3f}")
        print("-----------------------------")
        model.load_state_dict(weights_before)  # restore from snapshot
print("FINISHED")