In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.autograd import Variable

from tqdm import tqdm, trange

import copy

from sklearn.metrics import r2_score

In [2]:
sns.set_context("paper")
sns.set_style("ticks")

np.random.seed(101217209)
torch.manual_seed(101217209)


Out[2]:
<torch._C.Generator at 0x7f9d553305e8>

In [3]:
def get_data(N, min_x, max_x):
    w, b = np.random.randn(2)
    def true_model(X):
        lin_factor = w * X + b
        y = np.sin(lin_factor*10)*(lin_factor**2)
        return y
    X_true = np.arange(min_x, max_x, 0.01)
    y_true = true_model(X_true)
    span = (max_x - min_x)
    scale = 0.2
    # Sample observations uniformly, keeping clear of the edges of the range
    X_obs = min_x + span*scale + np.random.rand(N)*(span - 2*scale*span)
    y_obs = true_model(X_obs) + np.random.randn(N)*0.1
    
    y_mean, y_std = y_obs.mean(), y_obs.std()
    y_obs = (y_obs - y_mean)/y_std
    y_true = (y_true - y_mean)/y_std
    return (X_obs, y_obs, X_true, y_true), (w, b, true_model)

In [4]:
N = 40
min_x, max_x = -1, 1
(X_obs, y_obs, X_true, y_true), (w, b, true_model) = get_data(N, min_x, max_x)
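
EWC penalizes movement away from the parameters learned on previous tasks, weighted by a diagonal Fisher information estimate (the expected squared gradient of the loss). The `ewc_criterion` below implements exactly this quadratic penalty:

$$ \mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \alpha \sum_i F_i \left(\theta_i - \theta_i^{*}\right)^{2}, \qquad F_i = \mathbb{E}\left[\left(\frac{\partial \mathcal{L}}{\partial \theta_i}\right)^{2}\right] $$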

In [5]:
class SimpleModel(torch.nn.Module):
    def __init__(self, p=0.05, decay=0.001, non_linearity=torch.nn.ReLU):
        super(SimpleModel, self).__init__()
        self.dropout_p = p
        self.decay = decay
        self.f = torch.nn.Sequential(
            torch.nn.Linear(1,20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20,1)
        )
    def forward(self, X):
        X = Variable(torch.Tensor(X), requires_grad=False)
        return self.f(X)

def ewc_criterion(model, fisher_matrix, prior_params, alpha=1):
    ewc_loss = Variable(torch.zeros(1))
    for p, fm, prior in zip(model.parameters(), fisher_matrix, prior_params):
        param_loss = torch.pow(p - Variable(prior, requires_grad=False), 2)
        fisher_scaled_loss = alpha*Variable(fm, requires_grad=False).mul(param_loss)
        ewc_loss += fisher_scaled_loss.sum()
    return ewc_loss    
    
def get_criterion_optimizer(model):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(
        model.parameters(),
        weight_decay=model.decay)
    return criterion, optimizer


def plot_predictions(model, model_simple, ax=None, uncertainty=0, N=0, l2=0.005):
    if ax is None:
        fig, ax = plt.subplots(1,1)
    ax.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
    ax.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
    ax.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
    ax.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")

    for model_t, model_name, color in [
        (model, "ewc", "r"),
        (model_simple, "simple", "g")
    ]:
        if uncertainty > 0:
            model_t.train()
            y_preds = []
            for i in trange(uncertainty):
                y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
                y_preds.append(y_pred)
            y_preds = np.hstack(y_preds)
            y_pred = y_preds.mean(axis=1)
            y_var = y_preds.var(axis=1)
            # MC-dropout model precision (Gal & Ghahramani): tau = l^2 (1 - p) / (2 N lambda)
            tau = l2 * (1 - model_t.dropout_p) / (2 * N * model_t.decay)
            y_var += (1/tau)
            y_std = np.sqrt(y_var)
            for i in range(4):
                ax.fill_between(
                    X_true,
                    y_pred - y_std * ((i+1)/2),
                    y_pred + y_std * ((i+1)/2),
                    color=color,
                    alpha=0.1**(i+1)
                )
        else:
            model_t.eval()
            y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
        r_squared = r2_score(y_true, y_pred)
        ax.plot(X_true, y_pred, ls="--", lw=1, color=color,
                label="{} ({:.3f})".format(model_name, r_squared))
    ax.legend()
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    sns.despine(offset=10)
    return ax
    
def plot_losses(losses, losses_simple, ax=None):
    if ax is None:
        fig, ax = plt.subplots(1,1,figsize=(10,4))
    ax.plot(losses, lw=0.1, color='r', label='ewc')
    ax.plot(losses_simple, lw=0.1, color='0.5', label='simple')
    ax.legend()
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    sns.despine(offset=10)
    return ax


def get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=None):
    fisher_matrix = [torch.zeros(p.data.shape) for p in model.parameters()]
    counts = 0
    for x, y in tqdm(zip(X[:, np.newaxis, np.newaxis], Y[:, np.newaxis])):
        model.zero_grad()  # clear accumulated gradients on the model (not the criterion)
        y_pred = model(x)
        loss = criterion(y_pred, Variable(torch.Tensor(y), requires_grad=False))
        loss.backward()
        for i, p in enumerate(model.parameters()):
            # Diagonal Fisher estimate: accumulate squared per-sample gradients
            fisher_matrix[i] += torch.pow(p.grad.data, 2)
        counts += 1
    ## Average over samples
    fisher_matrix = [k/counts for k in fisher_matrix]
    if old_fisher_matrix:
        fisher_matrix = [k + o for k,o in zip(fisher_matrix, old_fisher_matrix)]
    return fisher_matrix
    
    
def train_step(
    model, X, Y, criterion, optimizer,
    fisher_matrix=None, model_params=None,
    ewc_alpha=0, clip=10):
    optimizer.zero_grad()  # reset gradients accumulated in the previous step
    y_pred = model(X[:, np.newaxis])
    loss = criterion(y_pred, Variable(torch.Tensor(Y), requires_grad=False))
    train_loss = loss.data.numpy()[0]
    if ewc_alpha > 0:
        ewc_loss = ewc_criterion(model, fisher_matrix, model_params, alpha=ewc_alpha)
        loss += ewc_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm(model.parameters(), clip)
    optimizer.step()
    return train_loss
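
The uncertainty bands in `plot_predictions` use MC dropout: dropout stays active at prediction time, the mean and variance are estimated over repeated stochastic forward passes, and the inverse model precision $1/\tau$ is added to the sample variance. Assuming `l2` denotes the squared prior length-scale, $\tau$ follows Gal & Ghahramani:

$$ \tau = \frac{l^2 (1 - p)}{2 N \lambda}, \qquad \operatorname{Var}[y^*] \approx \frac{1}{\tau} + \frac{1}{T}\sum_{t=1}^{T} \hat{y}_t^{\,2} - \left(\frac{1}{T}\sum_{t=1}^{T} \hat{y}_t\right)^{2} $$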

In [6]:
def train_plot_compare(X_obs, y_obs, prefix="IID", clip=10, uncertainty=0, l2=0.01, **kwargs):
    model = SimpleModel(**kwargs)
    model_simple = copy.deepcopy(model)
    
    fig, ax = plt.subplots(1,3, sharex=True, sharey=True, figsize=(18, 6))
    fig_loss, ax_loss = plt.subplots(3,1, sharex=True, sharey=False, figsize=(10, 12))

    ## Train sequentially on each of the three tasks

    fisher_matrix = None
    model_params = None
    N = 0
    for i, X, Y, alpha in zip(
        range(3),
        [X_obs[:10], X_obs[10:20], X_obs[20:]],
        [y_obs[:10], y_obs[10:20], y_obs[20:]],
        [0, 1e6, 1e6] # EWC penalty weight: zero for the first task (no prior tasks), large afterwards
    ):
        title = "Task {}".format(" + ".join(str(k+1) for k in range(i+1)))
        model.train()
        criterion, optimizer = get_criterion_optimizer(model)

        model_simple.train()
        criterion_simple, optimizer_simple = get_criterion_optimizer(model_simple)

        losses = []
        losses_simple = []
        for epoch in trange(10000):
            loss = train_step(
                model, X, Y,
                criterion, optimizer,
                fisher_matrix=fisher_matrix,
                model_params=model_params,
                ewc_alpha=alpha, clip=clip
            )
            losses.append(loss)

            loss = train_step(
                model_simple, X, Y, criterion_simple, optimizer_simple, clip=clip)
            losses_simple.append(loss)
        plot_losses(losses, losses_simple, ax=ax_loss[i])
        N = X.shape[0]
        plot_predictions(model, model_simple, ax=ax[i], uncertainty=uncertainty, N=N, l2=l2)
        ax_loss[i].set_title(title)
        ax[i].set_title(title)
        print(title)

        model.eval()
        fisher_matrix = get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=fisher_matrix)
        model_params = [copy.deepcopy(p.data) for p in model.parameters()]

    fig.tight_layout()
    fig_loss.tight_layout()
    
    sns.despine(offset=10, fig=fig)
    sns.despine(offset=10, fig=fig_loss)

    fig.savefig("{}_EWC_predictions.pdf".format(prefix), bbox_inches="tight")
    fig.savefig("{}_EWC_losses.pdf".format(prefix), bbox_inches="tight")
    
    return model, model_simple

IID samples


In [7]:
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)



In [8]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="IID",
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)


100%|██████████| 10000/10000 [00:21<00:00, 463.59it/s]
Task 1
100%|██████████| 10000/10000 [00:37<00:00, 269.22it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:38<00:00, 262.71it/s]
Task 1 + 2 + 3

In [9]:
"""
model.eval()
fisher_matrix = get_fisher_matrix(model, criterion, X_obs[:10], y_obs[:10])
model_params = [copy.deepcopy(p.data) for p in model.parameters()]

fig = plt.figure(0)
ax = fig.add_subplot(111)


el = plt.matplotlib.patches.Ellipse(
    xy=[
        model_params[0].squeeze().numpy()[0],
        model_params[1].numpy()[0]
    ],
    width=2./fisher_matrix[0].squeeze().numpy()[0],
    height=2./fisher_matrix[1].numpy()[0])
ax.add_artist(el)
#el.set_clip_box(ax.bbox)
el.set_color('0.5')
el.set_alpha(0.1)
ax.scatter(model_params[0].squeeze().numpy(), model_params[1].numpy(), s=1)
"""



Sequential samples


In [10]:
sorted_idx = np.argsort(X_obs)
X_obs = X_obs[sorted_idx]
y_obs = y_obs[sorted_idx]

In [11]:
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)



In [12]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="Seq",
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)


100%|██████████| 10000/10000 [00:21<00:00, 465.14it/s]
Task 1
100%|██████████| 10000/10000 [00:37<00:00, 267.17it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:38<00:00, 258.46it/s]
Task 1 + 2 + 3

Improvements

One possible improvement: before each forward pass, perturb each parameter with Gaussian noise scaled by the inverse square root of its Fisher information, so that well-constrained weights stay put while poorly constrained ones keep exploring:

$$ \begin{aligned} \sigma_\theta &= F_\theta^{-1/2} \\ \epsilon &\sim \mathcal{N}(0, 1) \\ \theta &\leftarrow \theta + \sigma_\theta \, \epsilon \end{aligned} $$
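
A minimal sketch of that perturbation, assuming the diagonal Fisher matrix returned by get_fisher_matrix; the helpers perturb_parameters and restore_parameters are hypothetical and not used in the cells above:

In [ ]:
def perturb_parameters(model, fisher_matrix, eps=1e-8):
    # Save the unperturbed weights so the caller can restore them later
    originals = [p.data.clone() for p in model.parameters()]
    for p, fm in zip(model.parameters(), fisher_matrix):
        sigma = torch.rsqrt(fm + eps)      # sigma_theta = F_theta^(-1/2)
        noise = torch.randn(p.data.shape)  # epsilon ~ N(0, 1)
        p.data += sigma * noise            # theta <- theta + sigma_theta * epsilon
    return originals

def restore_parameters(model, originals):
    for p, orig in zip(model.parameters(), originals):
        p.data.copy_(orig)

Calling perturb_parameters before a forward pass and restore_parameters after the update would let loosely constrained weights explore while anchoring those the Fisher information marks as important.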

In [13]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="Seq", uncertainty=100, l2=0.01,
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU
)


100%|██████████| 10000/10000 [00:21<00:00, 475.00it/s]
Task 1
100%|██████████| 10000/10000 [00:37<00:00, 266.63it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:37<00:00, 264.95it/s]
Task 1 + 2 + 3

In [ ]: