In [ ]:
import numpy as np

import torch
from torch import nn

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [10,6]

from copy import deepcopy

In [ ]:
seed = 0
plot = True

## Inner Optimisations
inner_stepsize = 0.02 # stepsize in inner SGD
inner_epochs   = 1 # number of epochs of each inner SGD
inner_batchsize= 10 # Size of training minibatches

## Outer Optimisations
# initial stepsize of outer optimization, i.e., meta-optimization
outer_stepsize0 = 0.1 
# number of outer updates; each iteration we sample one task and update on it
outer_steps = 10000

In [ ]:
torch.manual_seed(seed+0)
rng = np.random.RandomState(seed)
rng

In [ ]:
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points

In [ ]:
# Generate the regression tasks (random sine functions) we're going to learn about
def gen_task(): 
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(2., 5.)
    f_randomsine = lambda x : np.sin(x+phase)*ampl
    return f_randomsine # i.e. return a random *function* to learn
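
For a quick feel for the task distribution, the next cell (an added illustration, not part of the original demo) samples a few tasks and plots them.

In [ ]:
# Added illustration: sample a few random sine tasks and plot them.
# Note: this advances `rng`, so later random draws differ slightly from
# a run that skips this cell.
if plot:
    for _ in range(3):
        f_example = gen_task()   # `f_example` is a name introduced just for this plot
        plt.plot(x_all, f_example(x_all))
    plt.title("A few tasks sampled from the sine-wave distribution")
    plt.show()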

In [ ]:
# Define model - this model is going to be easily *trainable* to 
#  solve each of the problems we generate
model = nn.Sequential(
    nn.Linear(1, 64),
    nn.Tanh(), # Reptile paper uses ReLU, but Tanh gives slightly better results
    nn.Linear(64, 64),
    nn.Tanh(),
    nn.Linear(64, 1),
)

def train_current_model_on_batch(x, y):
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)
    model.zero_grad()
    ypred = model(x)
    loss = (ypred - y).pow(2).mean()  # mean squared error on this minibatch
    loss.backward()
    for param in model.parameters():
        param.data -= inner_stepsize * param.grad.data  # manual SGD step (no optimizer object)

def predict_using_current_model(x):
    x = torch.tensor(x, dtype=torch.float32)
    return model(x).data.numpy()
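
As a quick sanity check (added here, not part of the original), an untrained forward pass should return one prediction per input point.

In [ ]:
# Added sanity check: the untrained model maps the (50, 1) array of x values
# to a (50, 1) array of predictions.
predict_using_current_model(x_all).shape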

In [ ]:
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()  # This is one specific task
xtrain_plot = x_all[rng.choice(len(x_all), size=inner_batchsize)]

In [ ]:
# Reptile training loop
for outer_step in range(outer_steps):
    weights_before = deepcopy(model.state_dict())
    
    # Sample a task from the distribution we want to learn to learn
    f = gen_task()  # This is a different task every time
    y_all = f(x_all) # Get the correct outputs for the x_all values (i.e. a dataset)
    
    # Do SGD on this task + dataset
    inds = rng.permutation(len(x_all))  # Shuffle data indices
    for _ in range(inner_epochs):
        for start in range(0, len(x_all), inner_batchsize):
            batch_indices = inds[start:start+inner_batchsize]
            train_current_model_on_batch(
                x_all[batch_indices], y_all[batch_indices]
            )
            
    # Interpolate between current weights and trained weights from this task
    #   i.e. (weights_before - weights_after) is the meta-gradient
    weights_after = model.state_dict()
    
    outer_stepsize = outer_stepsize0*(1-outer_step/outer_steps) # linear schedule
    
    # This updates the weights in the model -
    #   Not a 'training' gradient update, but one reflecting
    #   the training that occurred during the inner loop training
    model.load_state_dict({name : 
        weights_before[name] + (weights_after[name]-weights_before[name])*outer_stepsize 
        for name in weights_before})

    # Periodically plot the results on a particular task and minibatch
    if plot and (outer_step == 0 or (outer_step+1) % 1000 == 0):
        fig = plt.figure(figsize=(10,6))
        ax = plt.subplot(111)
        
        f = f_plot
        
        weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
        
        # Plot the initial model (having seen no data to train on for this task)
        plt.plot(x_all, predict_using_current_model(x_all), 
                 label="pred after 0", color=(0.5,0.5,1))

        for inner_step in range(1, 32+1):
            train_current_model_on_batch(xtrain_plot, f(xtrain_plot))
            if inner_step in [1, 2, 4, 8, 16, 32]:
                frac = np.log(inner_step)/np.log(32.)
                plt.plot(x_all, predict_using_current_model(x_all), 
                         label="pred after %i" % (inner_step), 
                         color=(frac, 0, 1-frac))
        lossval = np.square(predict_using_current_model(x_all) - f(x_all)).mean()

        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
        
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4.5,4.5)
        
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))        
        
        plt.show()
        
        print(f"outer_step             {outer_step+1:6d}")
        # would be better to average loss over a set of examples, 
        #   but this is optimized for brevity
        print(f"loss on plotted curve   {lossval:.3f}") 
        print(f"-----------------------------")

        model.load_state_dict(weights_before) # restore from snapshot 
print("FINISHED")
