In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.autograd import Variable
from tqdm import tqdm, trange
import copy
from sklearn.metrics import r2_score
In [2]:
sns.set_context("paper")
sns.set_style("ticks")
np.random.seed(101217209)
torch.manual_seed(101217209)
In [3]:
def get_data(N, min_x, max_x):
    # A random linear factor inside a sine envelope gives a wiggly 1-D target.
    w, b = np.random.randn(2)

    def true_model(X):
        lin_factor = w * X + b
        y = np.sin(lin_factor * 10) * (lin_factor ** 2)
        return y

    X_true = np.arange(min_x, max_x, 0.01)
    y_true = true_model(X_true)
    # Sample observations only from the inner 60% of the range, so the
    # edges of the plots show extrapolation behaviour.
    span = max_x - min_x
    scale = 0.2
    X_obs = min_x + span * scale + np.random.rand(N) * (span - 2 * scale * span)
    y_obs = true_model(X_obs) + np.random.randn(N) * 0.1
    # Standardise y using the observed points' statistics.
    y_mean, y_std = y_obs.mean(), y_obs.std()
    y_obs = (y_obs - y_mean) / y_std
    y_true = (y_true - y_mean) / y_std
    return (X_obs, y_obs, X_true, y_true), (w, b, true_model)
In [4]:
N = 40
min_x, max_x = -1, 1
(X_obs, y_obs, X_true, y_true), (w, b, true_model) = get_data(N, min_x, max_x)
In [5]:
class SimpleModel(torch.nn.Module):
    def __init__(self, p=0.05, decay=0.001, non_linearity=torch.nn.ReLU):
        super(SimpleModel, self).__init__()
        self.dropout_p = p
        self.decay = decay
        # Two hidden layers with dropout after each, so MC-dropout sampling
        # can be used at prediction time.
        self.f = torch.nn.Sequential(
            torch.nn.Linear(1, 20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 1)
        )

    def forward(self, X):
        X = Variable(torch.Tensor(X), requires_grad=False)
        return self.f(X)
def ewc_criterion(model, fisher_matrix, prior_params, alpha=1):
    # Quadratic penalty pulling each parameter towards its value after the
    # previous task, weighted by its (diagonal) Fisher information.
    ewc_loss = Variable(torch.zeros(1))
    for p, fm, prior in zip(model.parameters(), fisher_matrix, prior_params):
        param_loss = torch.pow(p - Variable(prior, requires_grad=False), 2)
        fisher_scaled_loss = alpha * Variable(fm, requires_grad=False).mul(param_loss)
        ewc_loss += fisher_scaled_loss.sum()
    return ewc_loss
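# For reference, a sketch of the objective ewc_criterion implements
# (Kirkpatrick et al., "Overcoming catastrophic forgetting"; constants
# absorbed into alpha here):
#   L(theta) = L_task(theta) + alpha * sum_i F_i * (theta_i - theta_i*)^2
# where theta* are the parameters snapshotted after the previous task and
# F_i is the diagonal Fisher information estimated by get_fisher_matrix below.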
def get_criterion_optimizer(model):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(
        model.parameters(),
        weight_decay=model.decay)
    return criterion, optimizer
def plot_predictions(model, model_simple, ax=None, uncertainty=0, N=0, l2=0.005):
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    ax.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
    ax.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
    ax.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
    ax.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
    for model_t, model_name, color in [
        (model, "ewc", "r"),
        (model_simple, "simple", "g")
    ]:
        if uncertainty > 0:
            # MC dropout: keep the model in train mode so dropout stays
            # active, and average over `uncertainty` stochastic forward passes.
            model_t.train()
            y_preds = []
            for i in trange(uncertainty):
                y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
                y_preds.append(y_pred)
            y_preds = np.hstack(y_preds)
            y_pred = y_preds.mean(axis=1)
            y_var = y_preds.var(axis=1)
            # Model precision tau from Gal & Ghahramani; note that torch's
            # Dropout(p) drops with probability p, so the keep probability
            # entering the formula is (1 - p).
            tau = l2 * (1 - model_t.dropout_p) / (2 * N * model_t.decay)
            y_var += 1 / tau  # add the inherent observation noise
            y_std = np.sqrt(y_var)
            for i in range(4):
                ax.fill_between(
                    X_true,
                    y_pred - y_std * ((i + 1) / 2),
                    y_pred + y_std * ((i + 1) / 2),
                    color=color,
                    alpha=0.1 ** (i + 1)
                )
        else:
            model_t.eval()
            y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
        r_squared = r2_score(y_true, y_pred)
        ax.plot(X_true, y_pred, ls="--", lw=1, color=color,
                label="{} ({:.3f})".format(model_name, r_squared))
    ax.legend()
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    sns.despine(offset=10)
    return ax
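# Sketch of the MC-dropout predictive variance used above (assuming the
# Gal & Ghahramani setup, with the l2 argument playing the role of the
# squared prior length-scale):
#   Var[y*] ~= 1/tau + (1/T) * sum_t y_t^2 - y_bar^2
#   tau = l^2 * (1 - p) / (2 * N * lambda)
# where T = `uncertainty` forward passes, p is the dropout probability and
# lambda is the weight decay.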
def plot_losses(losses, losses_simple, ax=None):
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.plot(losses, lw=0.1, color="r", label="ewc")
    ax.plot(losses_simple, lw=0.1, color="0.5", label="simple")
    ax.legend()
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    sns.despine(offset=10)
    return ax
def get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=None):
    # Diagonal empirical Fisher: per-parameter mean of the squared
    # gradients of the loss, taken one observation at a time.
    fisher_matrix = [torch.zeros(p.data.shape) for p in model.parameters()]
    counts = 0
    for x, y in tqdm(zip(X[:, np.newaxis, np.newaxis], Y[:, np.newaxis])):
        model.zero_grad()  # clear gradients on the model, not the criterion
        y_pred = model(x)
        loss = criterion(y_pred, Variable(torch.Tensor(y), requires_grad=False))
        loss.backward()
        for i, p in enumerate(model.parameters()):
            fisher_matrix[i] += torch.pow(p.grad.data, 2)
        counts += 1
    # Average over observations.
    fisher_matrix = [k / counts for k in fisher_matrix]
    if old_fisher_matrix is not None:
        # Accumulate Fisher information across tasks.
        fisher_matrix = [k + o for k, o in zip(fisher_matrix, old_fisher_matrix)]
    return fisher_matrix
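# For reference, the estimate above is
#   F_i ~= (1/N) * sum_n (d L(x_n, y_n) / d theta_i)^2
# i.e. a diagonal approximation of the Fisher information, with squared
# MSE-loss gradients standing in for squared log-likelihood gradients
# (proportional under a Gaussian noise model).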
def train_step(
        model, X, Y, criterion, optimizer,
        fisher_matrix=None, model_params=None,
        ewc_alpha=0, clip=10):
    optimizer.zero_grad()
    y_pred = model(X[:, np.newaxis])
    loss = criterion(y_pred, Variable(torch.Tensor(Y), requires_grad=False))
    train_loss = loss.data.numpy()[0]
    if ewc_alpha > 0:
        # Add the EWC penalty anchored at the previous task's parameters.
        loss += ewc_criterion(model, fisher_matrix, model_params, alpha=ewc_alpha)
    loss.backward()
    torch.nn.utils.clip_grad_norm(model.parameters(), clip)
    optimizer.step()
    return train_loss
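# The overall EWC recipe, as wired together in train_plot_compare below:
#   1. train on the current task's data,
#   2. estimate the (diagonal) Fisher information on that data,
#   3. snapshot the parameters,
#   4. on the next task, add the Fisher-weighted quadratic penalty.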
In [6]:
def train_plot_compare(X_obs, y_obs, prefix="IID", clip=10, uncertainty=0, l2=0.01, **kwargs):
    model = SimpleModel(**kwargs)
    model_simple = copy.deepcopy(model)  # identical initialisation for both models
    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(18, 6))
    fig_loss, ax_loss = plt.subplots(3, 1, sharex=True, sharey=False, figsize=(10, 12))
    fisher_matrix = None
    model_params = None
    for i, X, Y, alpha in zip(
        range(3),
        [X_obs[:10], X_obs[10:20], X_obs[20:]],
        [y_obs[:10], y_obs[10:20], y_obs[20:]],
        [0, 1e6, 1e6]  # no EWC penalty on the first task; strong penalty afterwards
    ):
        title = "Task {}".format(" + ".join(str(k + 1) for k in range(i + 1)))
        model.train()
        criterion, optimizer = get_criterion_optimizer(model)
        model_simple.train()
        criterion_simple, optimizer_simple = get_criterion_optimizer(model_simple)
        losses = []
        losses_simple = []
        for epoch in trange(10000):
            loss = train_step(
                model, X, Y,
                criterion, optimizer,
                fisher_matrix=fisher_matrix,
                model_params=model_params,
                ewc_alpha=alpha, clip=clip
            )
            losses.append(loss)
            loss = train_step(
                model_simple, X, Y, criterion_simple, optimizer_simple, clip=clip)
            losses_simple.append(loss)
        plot_losses(losses, losses_simple, ax=ax_loss[i])
        N = X.shape[0]
        plot_predictions(model, model_simple, ax=ax[i], uncertainty=uncertainty, N=N, l2=l2)
        ax_loss[i].set_title(title)
        ax[i].set_title(title)
        print(title)
        # Estimate the Fisher information on this task (dropout disabled)
        # and snapshot the parameters for the next task's EWC penalty.
        model.eval()
        fisher_matrix = get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=fisher_matrix)
        model_params = [copy.deepcopy(p.data) for p in model.parameters()]
    fig.tight_layout()
    fig_loss.tight_layout()
    sns.despine(offset=10, fig=fig)
    sns.despine(offset=10, fig=fig_loss)
    fig.savefig("{}_EWC_predictions.pdf".format(prefix), bbox_inches="tight")
    fig_loss.savefig("{}_EWC_losses.pdf".format(prefix), bbox_inches="tight")
    return model, model_simple
In [7]:
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)
In [8]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="IID",
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)
In [9]:
"""
model.eval()
fisher_matrix = get_fisher_matrix(model, criterion, X_obs[:10], y_obs[:10])
model_params = [copy.deepcopy(p.data) for p in model.parameters()]
fig = plt.figure(0)
ax = fig.add_subplot(111)
el = plt.matplotlib.patches.Ellipse(
xy=[
model_params[0].squeeze().numpy()[0],
model_params[1].numpy()[0]
],
width=2./fisher_matrix[0].squeeze().numpy()[0],
height=2./fisher_matrix[1].numpy()[0])
ax.add_artist(el)
#el.set_clip_box(ax.bbox)
el.set_color('0.5')
el.set_alpha(0.1)
ax.scatter(model_params[0].squeeze().numpy(), model_params[1].numpy(), s=1)
"""
In [10]:
# Sort the observations so the three "tasks" now occupy disjoint, ordered
# regions of x: a sequential (non-IID) task split instead of an IID one.
sorted_idx = np.argsort(X_obs)
X_obs = X_obs[sorted_idx]
y_obs = y_obs[sorted_idx]
In [11]:
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)
In [12]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="Seq",
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)
In [13]:
model, model_simple = train_plot_compare(
    X_obs, y_obs, prefix="Seq_uncertainty", uncertainty=100, l2=0.01,
    p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU
)