In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import pandas as pd
import random
import seaborn as sbs
from tqdm import tqdm_notebook as tqdm
sbs.set_style('darkgrid')
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

This notebook reproduces both MAML and the closely related Reptile.

The Problem

https://towardsdatascience.com/fun-with-small-image-data-sets-8c83d95d0159

The goal of both of these algorithms is to learn to do well at the K-shot learning problem.

In K-shot learning, we need to train a neural network to generalize based on a very small number of examples (often on the order of 10 or so) instead of the often thousands of examples we see in datasets like ImageNet. However, in preparation for K-shot learning, you are allowed to train on many similar K-shot problems to learn the best way to generalize based on only K examples.

This is learning to learn or metalearning. We have already seen metalearning in my post on "Learning to Learn by Gradient Descent by Gradient Descent", which you can find here:

https://becominghuman.ai/paper-repro-learning-to-learn-by-gradient-descent-by-gradient-descent-6e504cc1c0de

The metalearning approach of both Reptile and MAML is to come up with an initialization for neural networks that is easily generalizable to similar tasks. This is different to "Learning to Learn by Gradient Descent by Gradient Descent" in which we weren't learning an initialization but rather an optimizer.

This approach is very similar to transfer learning, in which we train a network on, say, ImageNet, and it later turns out that fine-tuning this network makes it easy to learn another image dataset, with much less data. Indeed, transfer learning can be seen as a form of metalearning.

The difference here is that the initial network was trained with the explicit purpose of being easily generalizable, whereas transfer learning just "accidentally" happens to work, and thus might not work optimally.

Indeed, it is fairly easy to find a case in which transfer learning fails to learn a good initialization. For this we need to look at the 1D sine wave regression problem.

Sine Wave Regression

In this K-shot problem, each task consists in learning a modified sine function. Specifically, for each task, the underlying function will be of the form y = a sin(x + b), with both a and b chosen randomly, and the goal of our neural network is to learn to find y given x based on only 10 (x, y) pairs.

Let's write our sine wave task and plot a couple of examples:


In [2]:
class SineWaveTask:
    def __init__(self):
        self.a = np.random.uniform(0.1, 5.0)
        self.b = np.random.uniform(0, 2*np.pi)
        self.train_x = None
        
    def f(self, x):
        return self.a * np.sin(x + self.b)
        
    def training_set(self, size=10, force_new=False):
        if self.train_x is None and not force_new:
            self.train_x = np.random.uniform(-5, 5, size)
            x = self.train_x
        elif not force_new:
            x = self.train_x
        else:
            x = np.random.uniform(-5, 5, size)
        y = self.f(x)
        return torch.Tensor(x), torch.Tensor(y)
    
    def test_set(self, size=50):
        x = np.linspace(-5, 5, size)
        y = self.f(x)
        return torch.Tensor(x), torch.Tensor(y)
    
    def plot(self, *args, **kwargs):
        x, y = self.test_set(size=100)
        return plt.plot(x.numpy(), y.numpy(), *args, **kwargs)
    
SineWaveTask().plot()
SineWaveTask().plot()
SineWaveTask().plot()
plt.show()


To understand why this is going to be a problem for transfer learning, let's plot 1,000 of them:


In [3]:
for _ in range(1000):
    SineWaveTask().plot(color='black')


Looks like there is a lot of overlap at each x value, to say the least...

Since there are multiple possible values for each x across multiple tasks, if we train a single neural net to deal with multiple tasks at the same time, its best bet will simply be to return the average y value across all tasks for each x. What does that look like?


In [4]:
all_x, all_y = [], []

for _ in range(10000):
    curx, cury = SineWaveTask().test_set(size=100)
    all_x.append(curx.numpy())
    all_y.append(cury.numpy())

avg, = plt.plot(all_x[0], np.mean(all_y, axis=0))
rand, = SineWaveTask().plot()
plt.legend([avg, rand], ['Average', 'Random'])
plt.show()


The average is basically 0, which means a neural network trained on a lot of tasks would simply return 0 everywhere! It is unclear that this will actually help very much, and yet this is the transfer learning approach in this case...

Let's see how well it does by actually implementing the model:


In [5]:
TRAIN_SIZE = 10000
TEST_SIZE = 1000

class ModifiableModule(nn.Module):
    def params(self):
        return [p for _, p in self.named_params()]
    
    def named_leaves(self):
        return []
    
    def named_submodules(self):
        return []
    
    def named_params(self):
        subparams = []
        for name, mod in self.named_submodules():
            for subname, param in mod.named_params():
                subparams.append((name + '.' + subname, param))
        return self.named_leaves() + subparams
    
    def set_param(self, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in self.named_submodules():
                if module_name == name:
                    mod.set_param(rest, param)
                    break
        else:
            setattr(self, name, param)
            
    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            if not same_var:
                param = V(param.data.clone(), requires_grad=True)
            self.set_param(name, param)

class GradLinear(ModifiableModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)
        self.weights = V(ignore.weight.data, requires_grad=True)
        self.bias = V(ignore.bias.data, requires_grad=True)
        
    def forward(self, x):
        return F.linear(x, self.weights, self.bias)
    
    def named_leaves(self):
        return [('weights', self.weights), ('bias', self.bias)]

class SineModel(ModifiableModule):
    def __init__(self):
        super().__init__()
        self.hidden1 = GradLinear(1, 40)
        self.hidden2 = GradLinear(40, 40)
        self.out = GradLinear(40, 1)
        
    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        return self.out(x)
    
    def named_submodules(self):
        return [('hidden1', self.hidden1), ('hidden2', self.hidden2), ('out', self.out)]
    
SINE_TRAIN = [SineWaveTask() for _ in range(TRAIN_SIZE)]
SINE_TEST = [SineWaveTask() for _ in range(TEST_SIZE)]

ONE_SIDED_EXAMPLE = None
while ONE_SIDED_EXAMPLE is None:
    cur = SineWaveTask()
    x, _ = cur.training_set()
    x = x.numpy()
    if np.max(x) < 0 or np.min(x) > 0:
        ONE_SIDED_EXAMPLE = cur

SINE_TRANSFER = SineModel()

def sine_fit1(net, wave, optim=None, get_test_loss=False, create_graph=False, force_new=False):
    net.train()
    if optim is not None:
        optim.zero_grad()
    x, y = wave.training_set(force_new=force_new)
    loss = F.mse_loss(net(V(x[:, None])), V(y).unsqueeze(1))
    loss.backward(create_graph=create_graph, retain_graph=True)
    if optim is not None:
        optim.step()
    if get_test_loss:
        net.eval()
        x, y = wave.test_set()
        loss_test = F.mse_loss(net(V(x[:, None])), V(y).unsqueeze(1))
        return loss.item(), loss_test.item()
    return loss.item()

def fit_transfer(epochs=1):
    optim = torch.optim.Adam(SINE_TRANSFER.params())

    for _ in range(epochs):
        for t in random.sample(SINE_TRAIN, len(SINE_TRAIN)):
            sine_fit1(SINE_TRANSFER, t, optim)
            
fit_transfer()

In [6]:
def copy_sine_model(model):
    m = SineModel()
    m.copy(model)
    return m

def eval_sine_test(model, test, fits=(0, 1), lr=0.01):
    xtest, ytest = test.test_set()
    xtrain, ytrain = test.training_set()

    model = copy_sine_model(model)
    # Not sure if this should be Adam or SGD.
    optim = torch.optim.SGD(model.params(), lr)
        
    def get_loss(res):
        return F.mse_loss(res, V(ytest[:, None])).item()
    
    fit_res = []
    if 0 in fits:
        results = model(V(xtest[:, None]))
        fit_res.append((0, results, get_loss(results)))
    for i in range(np.max(fits)):
        sine_fit1(model, test, optim)
        if i + 1 in fits:
            results = model(V(xtest[:, None]))
            fit_res.append(
                (
                    i + 1, 
                    results,
                    get_loss(results)
                )
            )

    return fit_res

def plot_sine_test(model, test, fits=(0, 1), lr=0.01):
    xtest, ytest = test.test_set()
    xtrain, ytrain = test.training_set()

    fit_res = eval_sine_test(model, test, fits, lr)
    
    train, = plt.plot(xtrain.numpy(), ytrain.numpy(), '^')
    ground_truth, = plt.plot(xtest.numpy(), ytest.numpy())
    plots = [train, ground_truth]
    legend = ['Training Points', 'True Function']
    for n, res, loss in fit_res:
        cur, = plt.plot(xtest.numpy(), res.cpu().data.numpy()[:, 0], '--')
        plots.append(cur)
        legend.append(f'After {n} Steps')
    plt.legend(plots, legend)
    plt.show()
    
plot_sine_test(SINE_TRANSFER, SINE_TEST[0], fits=[0, 1, 10], lr=0.02)


Basically, it looks like our transfer model has learned a constant function, and it is really hard to fine-tune it into anything better than a constant function. It's not even clear that our transfer learning is any better than a random initialization...


In [7]:
def plot_sine_learning(models, fits=(0, 1), lr=0.01, marker='s', linestyle='--'):
    data = {'model': [], 'fits': [], 'loss': [], 'set': []}
    for name, model_list in models:
        if not isinstance(model_list, list):
            model_list = [model_list]
        for n_model, model in enumerate(model_list):
            for n_test, test in enumerate(SINE_TEST):
                n_test = n_model * len(SINE_TEST) + n_test
                fit_res = eval_sine_test(model, test, fits, lr)
                for n, _, loss in fit_res:
                    data['model'].append(name)
                    data['fits'].append(n)
                    data['loss'].append(loss)
                    data['set'].append(n_test)
        
    df = pd.DataFrame(data)
    # tsplot was removed from seaborn; lineplot with hue='model' aggregates the
    # per-task curves (mean and confidence band) for each model, as tsplot did.
    ax = sbs.lineplot(data=df, x='fits', y='loss', hue='model',
                      marker=marker, linestyle=linestyle)
    
plot_sine_learning(
    [('Transfer', SINE_TRANSFER), ('Random', SineModel())],
    list(range(100)),
    marker='',
    linestyle='-'
)



MAML

We now come to MAML, the first of the two algorithms we will look at today.

As mentioned before, we are trying to find a set of weights such that running gradient descent on similar tasks makes progress as quickly as possible. MAML takes this extremely literally by running one iteration of gradient descent and then updating the initial weights based on how much progress that one iteration made towards the true task. More concretely it:

  • Creates a copy of the initialization weights
  • Runs an iteration of gradient descent for a random task on the copy
  • Backpropagates the loss on a test set through that iteration of gradient descent, all the way back to the initial weights, so that we can update the initial weights in a direction that would have made that gradient step more effective (the update rule is written out just below).
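
In symbols, a single meta-update for one task looks roughly like this, where α is the inner learning rate, β the meta (outer) learning rate, and L_train, L_test are the task's losses on its training and test batches:

    θ ← θ − β ∇_θ L_test(θ − α ∇_θ L_train(θ))

The outer gradient with respect to θ is taken through the inner gradient step, which is where the second derivatives come from.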

We thus need to take a gradient of a gradient, a.k.a. a second-order derivative, in this process. Fortunately this is something that PyTorch now supports; unfortunately, PyTorch makes it a bit awkward to update the parameters of a model in a way that lets us still run gradient descent through them (we already saw this in "Learning to Learn by Gradient Descent by Gradient Descent"), which explains the somewhat unusual way in which the model is written.

Because we are going to use second derivatives, we need to make sure that the computational graph that allowed us to compute the original gradients stays around, which is why we pass create_graph=True to .backward().
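
As a minimal standalone sketch (not part of the model code in this notebook, and assuming a reasonably recent PyTorch; the names x, y, g, h exist only for this illustration), here is how create_graph lets autograd differentiate through a gradient:

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3                                           # dy/dx = 3x^2, d2y/dx2 = 6x
g, = torch.autograd.grad(y, x, create_graph=True)    # keep the graph of the gradient itself
h, = torch.autograd.grad(g, x)                       # differentiate the gradient
print(g.item(), h.item())                            # 27.0, 18.0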

The code below also implements first order MAML, which we explain later:


In [ ]:
def maml_sine(model, epochs, lr_inner=0.01, batch_size=1, first_order=False):
    optimizer = torch.optim.Adam(model.params())
    
    for _ in tqdm(range(epochs)):
        # Note: the paper doesn't specify the meta-batch size for this task,
        # so I just use 1 for now.
        for i, t in enumerate(random.sample(SINE_TRAIN, len(SINE_TRAIN))):
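            # same_var=True makes the copy share the *same* parameter Variables as
            # `model`, so the graph built by the inner update stays connected to the
            # meta-parameters and the meta-gradient can flow back into them.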
            new_model = SineModel()
            new_model.copy(model, same_var=True)
            loss = sine_fit1(new_model, t, create_graph=not first_order)
            for name, param in new_model.named_params():
                grad = param.grad
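                # For first-order MAML, detach the inner gradient so that no second
                # derivatives are propagated through it during the meta-update.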
                if first_order:
                    grad = V(grad.detach().data)
                new_model.set_param(name, param - lr_inner * grad)
                        
            # Meta-loss: evaluate the adapted model on a fresh batch from the same
            # task; its backward() accumulates gradients into `model`'s parameters.
            sine_fit1(new_model, t, force_new=True)

            if (i + 1) % batch_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                    
SINE_MAML = [SineModel() for _ in range(5)]

for m in SINE_MAML:
    maml_sine(m, 4)

In [ ]:
plot_sine_test(SINE_MAML[0], SINE_TEST[0], fits=[0, 1, 10], lr=0.01)
plt.show()

plot_sine_learning(
    [('Transfer', SINE_TRANSFER), ('MAML', SINE_MAML[0]), ('Random', SineModel())],
    list(range(10)),
)
plt.show()

plot_sine_test(SINE_MAML[0], ONE_SIDED_EXAMPLE, fits=[0, 1, 10], lr=0.01)
plt.show()

So MAML works much better than transfer learning or random initialization for this problem. Yay!

However, it is a bit annoying that we have to use second order derivatives for this... it forces the code to be complicated and it also makes things a fair bit slower (around 33% according to the paper, which matches what we shall see here).

Is there an approximation of MAML that doesn't use second-order derivatives? Of course: we can simply pretend that the gradients used for the inner gradient descent step came out of nowhere (i.e. treat them as constants), and update the initial parameters without these second-order terms. This is exactly what the first_order parameter in the code above does.

So how good is this first order approximation? Almost as good as the original MAML, as it turns out!


In [ ]:
SINE_MAML_FIRST_ORDER = [SineModel() for _ in range(5)]

for m in SINE_MAML_FIRST_ORDER:
    maml_sine(m, 4, first_order=True)

In [ ]:
plot_sine_test(SINE_MAML_FIRST_ORDER[0], SINE_TEST[0], fits=[0, 1, 10], lr=0.01)
plt.show()

plot_sine_learning(
    [('MAML', SINE_MAML), ('MAML First Order', SINE_MAML_FIRST_ORDER)],
    list(range(10)),
)
plt.show()

plot_sine_test(SINE_MAML_FIRST_ORDER[0], ONE_SIDED_EXAMPLE, fits=[0, 1, 10], lr=0.01)
plt.show()

Reptile

The first order approximation of MAML tells us that something interesting is going on: after all, it seems like the way the gradients were generated should matter for finding a good initialization, and yet apparently it doesn't matter all that much.

Reptile takes this idea even further by telling us to do the following: run SGD for a few iterations on a given task, and then move your initialization weights a little bit in the direction of the weights you obtained after your k iterations of SGD. An algorithm so simple, it takes only a couple lines of pseudocode:
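
Paraphrasing the pseudocode from the Reptile paper (φ is the vector of initialization weights, ε the outer step size):

    Initialize φ
    for iteration = 1, 2, 3, ... do
        Sample a task τ, with corresponding loss L_τ
        W ← SGD(L_τ, φ, k)        # k steps of SGD on L_τ, starting from φ
        φ ← φ + ε (W − φ)
    end for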

When I first read this, I was quite consternated: isn't this the same as training your weights alternatively on each task, just like in transfer learning? How would this ever work?

Indeed, the Reptile paper anticipates this very reaction:

You might be thinking “isn’t this the same as training on the expected loss Eτ [Lτ]?” and then checking if the date is April 1st.

As it happens, I am writing this on April 2nd, so this is all serious. So what's going on?

Well, if we had run SGD for a single iteration, we would indeed have ended up with something equivalent to the transfer learning described above. But we haven't: we run several iterations, so the weights we update towards each time actually depend indirectly on the second derivatives of the loss, similar to MAML.
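
To see where the second derivatives sneak in, take k = 2 inner steps with learning rate α on two minibatch losses L1 and L2 (a rough sketch of the Taylor expansion carried out in the Reptile paper):

    W = φ − α ∇L1(φ) − α ∇L2(φ − α ∇L1(φ))
      ≈ φ − α ∇L1(φ) − α ∇L2(φ) + α² ∇²L2(φ) ∇L1(φ)

The direction Reptile moves in, W − φ, therefore contains a Hessian-times-gradient term on top of the plain averaged gradient, and that term is exactly what training on the expected loss would miss.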

Ok, but still, why would this work? Well Reptile provides a compelling intuition for this: for each task, there are weights that are optimal. Indeed, there are probably many sets of weights that are optimal. This means that if you take several tasks, there should be a set of weights for which the distance to at least one optimal set of weights for each task is minimal. This set of weights is where we want to initialize our networks, since it is likely to be the one for which the least work is necessary to reach the optimum for any task. This is the set of weights that Reptile finds.
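
In the paper's analysis (roughly stated), Reptile behaves like stochastic gradient descent on the objective

    minimize over φ:  E_τ [ (1/2) dist(φ, W_τ*)² ]

where W_τ* is the set of optimal weights for task τ and dist(φ, W_τ*) is the Euclidean distance from φ to the nearest point of that set.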

We can see this expressed visually in the following image: the two black lines represent the sets of optimal weights for two different tasks, while the gray line represents the initialization weights. Reptile tries to get the initialization weights closer and closer to the point where the optimal weights are nearest to each other.

Let's now implement Reptile and compare it to MAML:


In [ ]:
def reptile_sine(model, epochs, lr_inner=0.01, lr_outer=0.001, k=32, batch_size=32):
    optimizer = torch.optim.Adam(model.params(), lr=lr_outer)
    
    name_to_param = dict(model.named_params())
    
    for _ in tqdm(range(epochs)):
        for i, t in enumerate(random.sample(SINE_TRAIN, len(SINE_TRAIN))):
            new_model = SineModel()
            new_model.copy(model)
            inner_optim = torch.optim.SGD(new_model.params(), lr=lr_inner)
            for _ in range(k):
                sine_fit1(new_model, t, inner_optim)
                        
            for name, param in new_model.named_params():
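                # Reptile's meta-"gradient": the direction from the adapted weights
                # back to the current initialization, rescaled by 1 / (k * lr_inner)
                # so that its magnitude resembles a single SGD gradient.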
                cur_grad = (name_to_param[name].data - param.data) / k / lr_inner
                if name_to_param[name].grad is None:
                    name_to_param[name].grad = V(torch.zeros(cur_grad.size()))
                name_to_param[name].grad.data.add_(cur_grad / batch_size)

            if (i + 1) % batch_size == 0:
                optimizer.step()
                optimizer.zero_grad()

SINE_REPTILE = [SineModel() for _ in range(5)]

for m in SINE_REPTILE:
    reptile_sine(m, 4, k=3, batch_size=1)

In [ ]:
plot_sine_test(SINE_REPTILE[0], SINE_TEST[0], fits=[0, 1, 10], lr=0.01)
plt.show()

plot_sine_learning(
    [('MAML', SINE_MAML), ('MAML First Order', SINE_MAML_FIRST_ORDER), ('Reptile', SINE_REPTILE)],
    list(range(32)),
)
plt.show()

plot_sine_test(SINE_REPTILE[0], ONE_SIDED_EXAMPLE, fits=[0, 1, 10], lr=0.01)
plt.show()

It looks like Reptile does indeed achieve performance similar to, or even slightly better than, MAML with a much simpler and slightly faster algorithm! Very exciting. And this applies to many more problems than just this toy example of sine waves: for more details, I really do recommend reading the papers. At this point, you should have enough background to understand them quite easily.

