Taken from http://mlg.eng.cam.ac.uk/yarin/blog_3d801aa532c1ce.html#uncertainty-sense
Also see: https://alexgkendall.com/computer_vision/bayesian_deep_learning_for_safe_ai/
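Both posts cover the same idea: leaving dropout switched on at prediction time turns a deterministic network into an approximate Bayesian one (Monte Carlo dropout), so the spread of repeated stochastic forward passes estimates predictive uncertainty. This notebook uses that uncertainty to drive a toy active-learning loop: fit on a few labelled points, then repeatedly query the most uncertain unlabelled point.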
In [1]:
%matplotlib inline
import numpy as np
import torch
from torch.autograd import Variable  # legacy wrapper; a no-op in PyTorch >= 0.4, kept as in the original
from matplotlib import animation
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm, trange
from ipywidgets import interact, fixed
from sklearn.metrics import r2_score
from IPython.display import display, HTML
In [2]:
sns.set_context("poster")
sns.set_style("ticks")
np.random.seed(1337)
In [3]:
def get_data(N, min_x, max_x):
    w, b = np.random.rand(2)

    def true_model(X):
        # Ground truth: a sine of an exponential, so the frequency
        # grows with X and extrapolation is genuinely hard.
        return np.sin(np.exp(w * X + b))

    X_true = np.linspace(min_x, max_x, N)
    y_true = true_model(X_true)
    # Observations only cover the middle half of the range, leaving the
    # edges unobserved so uncertainty should grow there.
    span = max_x - min_x
    scale = 0.25
    X_obs = min_x + span * scale + np.random.rand(N) * (span - 2 * scale * span)
    y_obs = true_model(X_obs) + np.random.randn(N) * 0.2
    # Standardise X and y using the observed data's statistics.
    X_mean, X_std = X_obs.mean(), X_obs.std()
    y_mean, y_std = y_obs.mean(), y_obs.std()
    X_obs = (X_obs - X_mean) / X_std
    X_true = (X_true - X_mean) / X_std
    y_obs = (y_obs - y_mean) / y_std
    y_true = (y_true - y_mean) / y_std
    return (X_obs, y_obs, X_true, y_true), (w, b, true_model)
In [4]:
N = 100
min_x, max_x = -20, 20
(X_obs, y_obs, X_true, y_true), (w, b, true_model) = get_data(N, min_x, max_x)
In [5]:
plt.plot(X_obs, y_obs, ls="none", marker="o", color="k", label="observed")
#plt.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
plt.legend()
sns.despine(offset=10)
In [6]:
plt.plot(X_obs, y_obs, ls="none", marker="o", color="k", label="observed")
plt.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
plt.legend()
sns.despine(offset=10)
In [7]:
class SimpleModel(torch.nn.Module):
    def __init__(self, p=0.05, decay=0.001, non_linearity=torch.nn.ReLU):
        super(SimpleModel, self).__init__()
        self.dropout_p = p
        self.decay = decay
        self.f = torch.nn.Sequential(
            torch.nn.Linear(1, 20),
            non_linearity(),  # was a hard-coded ReLU; the configurable activation is presumably intended for both hidden layers
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 1)
        )

    def forward(self, X):
        X = Variable(torch.Tensor(X), requires_grad=False)
        return self.f(X)
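MC dropout only works if the Dropout layers keep sampling masks at prediction time. An nn.Module starts in training mode and this notebook never calls model.eval(), so the stochastic passes below behave as intended. A minimal sketch making that explicit (the explicit train() call and the names mc_model/x_grid are illustrative additions, not part of the original):
In [ ]:
mc_model = SimpleModel(p=0.1)
mc_model.train()  # keep Dropout active so each forward pass samples a new mask
x_grid = np.linspace(-2, 2, 50)
samples = np.hstack([mc_model(x_grid[:, np.newaxis]).data.numpy() for _ in range(10)])
samples.std(axis=1)  # per-point spread across the 10 stochastic passes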
In [8]:
model = SimpleModel(p=0.1, decay=1e-6, non_linearity=torch.nn.Sigmoid)
In [9]:
def uncertainity_estimate(X, model, iters, N, l2=0.005, range_fn=range):
    # `iters` stochastic forward passes with dropout left active (MC dropout).
    outputs = np.hstack([model(X[:, np.newaxis]).data.numpy() for i in range_fn(iters)])
    y_mean = outputs.mean(axis=1)
    y_variance = outputs.var(axis=1)
    # Model precision tau from Gal & Ghahramani; `l2` plays the role of the
    # squared prior length-scale and `model.decay` of the weight decay.
    tau = l2 * (1 - model.dropout_p) / (2 * N * model.decay)
    y_variance += 1 / tau
    y_std = np.sqrt(y_variance)
    return y_mean, y_std
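The 1/tau term is the model-precision correction from Gal's derivation (first link above). With dropout probability p, prior length-scale l, N training points, and weight decay \lambda, a sketch of the relationship is

\tau = \frac{l^2 (1 - p)}{2 N \lambda},
\qquad
\operatorname{Var}[y^*] \approx \frac{1}{\tau} + \frac{1}{T} \sum_{t=1}^{T} \hat{y}_t^{\,2} - \Big( \frac{1}{T} \sum_{t=1}^{T} \hat{y}_t \Big)^{2}

where T is the number of stochastic passes (`iters` in the code), the code's `l2` argument plays the role of l^2, and `model.decay` the role of \lambda.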
In [10]:
y_mean, y_std = uncertainity_estimate(X_true, model, 200, N=1)
In [11]:
def plot_model(model, selected_idx, iters=200, l2=1, n_std=3, ax=None):
    if ax is None:
        plt.close("all")
        plt.clf()
        fig, ax = plt.subplots(1, 1)
    N = selected_idx.sum()
    # Predictive mean and std over the full input range.
    y_mean, y_std = uncertainity_estimate(X_true, model, N=N, iters=iters, l2=l2)
    ax.plot(X_true, y_true, ls="-", color="0.1", label="true", alpha=0.5)
    ax.plot(X_true, y_mean, ls="-", color="b", label="mean")
    for i in range(n_std):
        ax.fill_between(
            X_true,
            y_mean - y_std * ((i + 1) / 2),
            y_mean + y_std * ((i + 1) / 2),
            color="b",
            alpha=0.1
        )
    # Score on the observed points and pick the most uncertain
    # not-yet-selected point as the next query.
    y_mean, y_std = uncertainity_estimate(X_obs, model, N=N, iters=iters, l2=l2)
    next_idx = y_std[~selected_idx].argmax()
    next_idx = np.where(~selected_idx)[0][next_idx]
    R_score = r2_score(y_obs, y_mean)
    ax.set_title("R2={:.3f}, N={}, Next={}".format(R_score, N, next_idx))
    ax.plot(
        X_obs[~selected_idx],
        y_obs[~selected_idx],
        ls="none", marker="o",
        color="0.5", alpha=0.5,
        label="observed")
    ax.plot(
        X_obs[selected_idx],
        y_obs[selected_idx],
        ls="none", marker="*",
        color="purple", alpha=0.5,
        label="selected")
    ax.plot(
        X_obs[next_idx],
        y_obs[next_idx],
        ls="none", marker="*",
        color="r", alpha=0.5,
        ms=30,
        label="next")
    ax.legend()
    sns.despine(offset=10)
    selected_idx[next_idx] = True
    return ax, selected_idx, next_idx
In [12]:
selected_idx = np.full_like(X_obs, False, dtype=bool)  # np.bool was removed in modern NumPy
ax, selected_idx, next_idx = plot_model(model, selected_idx, n_std=3, l2=0.01)
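For reference, the acquisition rule buried in plot_model reduces to "query the pool point with the largest predictive standard deviation". Isolated as a standalone sketch (same names and hyperparameters as above, not a new cell from the original):
In [ ]:
pool = np.where(~selected_idx)[0]  # indices not yet labelled
_, y_std = uncertainity_estimate(X_obs, model, iters=200, N=selected_idx.sum(), l2=0.01)
next_idx = pool[y_std[pool].argmax()]  # most uncertain pool point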
In [13]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.001, momentum=0.9,
    weight_decay=model.decay)  # same decay as in the tau formula, as the derivation requires

def fit_model(model, selected_idx, optimizer):
    # One gradient step on the currently selected (labelled) points only.
    y = Variable(
        torch.Tensor(
            y_obs[selected_idx, np.newaxis]),
        requires_grad=False
    )
    y_pred = model(X_obs[selected_idx, np.newaxis])
    optimizer.zero_grad()
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss
In [14]:
fig = plt.figure(figsize=(10, 15))
ax0 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
ax1 = plt.subplot2grid((3, 1), (2, 0))
losses = []
for i in trange(1000):
    loss = fit_model(model, selected_idx, optimizer)
    losses.append(loss.item())  # loss.data.numpy()[0] fails on 0-dim tensors in modern PyTorch
print("loss={}".format(losses[-1]))
ax1.plot(losses, ls="-", lw=1, alpha=0.5)
ax0, selected_idx, next_idx = plot_model(model, selected_idx, ax=ax0, l2=1)
In [15]:
class AnimateTraining(object):
    def __init__(self, model, u_iters=200, l2=1, n_std=4, title=""):
        self.model = model
        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            weight_decay=model.decay)
        self.losses = []
        self.n_std = n_std
        self.u_iters = u_iters
        self.l2 = l2
        self.title = title
        # Start with a single randomly selected labelled point.
        self.selected_idx = np.full_like(X_obs, False, dtype=bool)
        self.selected_idx[np.random.randint(X_obs.shape[0])] = True
        ## plot items
        self.fig, self.ax0 = plt.subplots(1, 1)
        self.pts_obs, = self.ax0.plot(
            [], [],
            ls="none", marker="o",
            color="0.1", alpha=0.5, label="observed"
        )
        self.pts_selected, = self.ax0.plot(
            [], [],
            ls="none", marker="*",
            color="purple", alpha=0.5,
            label="selected"
        )
        self.pts_next, = self.ax0.plot(
            [], [],
            ls="none", marker="*",
            color="r", alpha=0.5,
            ms=30,
            label="next"
        )
        self.ax0.plot(X_true, y_true, ls="-", color="0.1", label="true")
        self.ln_mean, = self.ax0.plot([], [], ls="-", color="b", label="mean")
        self.loss_text = self.ax0.set_title('', fontsize=15)
        self.fill_stds = []
        for i in range(self.n_std):
            fill_t = self.ax0.fill_between(
                [], [], [],
                color="b",
                alpha=0.5 ** (i + 1)
            )
            self.fill_stds.append(fill_t)
        self.ax0.legend(loc="upper left")
    def query_next(self):
        # Acquisition: the not-yet-selected point with the largest
        # predictive standard deviation is queried next.
        N = self.selected_idx.sum()  # was the module-level `selected_idx`: a bug
        y_mean, y_std = uncertainity_estimate(
            X_obs, self.model, N=N, iters=self.u_iters, l2=self.l2)
        next_idx = y_std[~self.selected_idx].argmax()
        next_idx = np.where(~self.selected_idx)[0][next_idx]
        return next_idx
    def fit_model(self):
        y = Variable(
            torch.Tensor(
                y_obs[self.selected_idx, np.newaxis]),
            requires_grad=False
        )
        y_pred = self.model(X_obs[self.selected_idx, np.newaxis])
        self.optimizer.zero_grad()
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        return loss

    def init_plot(self):
        self.ln_mean.set_data([], [])
        self.loss_text.set_text('')
        return self.ln_mean, self.loss_text
    def animate_plot(self, i):
        # 100 optimisation steps per animation frame.
        for j in range(100):
            loss = self.fit_model().item()
            self.losses.append(loss)
        N = self.selected_idx.sum()
        y_mean, y_std = uncertainity_estimate(
            X_true, self.model, self.u_iters,
            N=N,
            l2=self.l2,
            range_fn=range
        )
        self.ln_mean.set_data(X_true, y_mean)
        for std_i in range(self.n_std):
            self.fill_stds[std_i].remove()
            self.fill_stds[std_i] = self.ax0.fill_between(
                X_true,
                y_mean - y_std * ((std_i + 1) / 2),
                y_mean + y_std * ((std_i + 1) / 2),
                color="b",
                alpha=0.1
            )
        # Score R2 on predictions at the observed X; comparing y_obs with
        # predictions on X_true would pair up unrelated points.
        y_mean_obs, _ = uncertainity_estimate(
            X_obs, self.model, self.u_iters, N=N, l2=self.l2)
        R_score = r2_score(y_obs, y_mean_obs)
        next_idx = None
        if N < X_obs.shape[0]:
            next_idx = self.query_next()
            self.pts_obs.set_data(
                X_obs[~self.selected_idx],
                y_obs[~self.selected_idx]
            )
            self.pts_next.set_data(
                X_obs[next_idx],
                y_obs[next_idx],
            )
        self.loss_text.set_text("{}, loss[{}]={:.3f}, R2={:.3f}, N={}, Next={}".format(
            self.title, (i + 1) * 100, loss, R_score, N, next_idx))
        self.pts_selected.set_data(
            X_obs[self.selected_idx],
            y_obs[self.selected_idx],
        )
        if next_idx is not None:  # `if next_idx:` would wrongly skip index 0
            self.selected_idx[next_idx] = True
        return ([
            self.ln_mean,
            self.pts_obs,
            self.pts_selected,
            self.pts_next,
            self.loss_text
        ] + self.fill_stds)
    def train(self, iters, interval=100):
        anim = animation.FuncAnimation(
            self.fig, self.animate_plot, init_func=self.init_plot,
            frames=range(iters), interval=interval, blit=True)
        return HTML(anim.to_html5_video())
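Note that anim.to_html5_video() encodes the animation with matplotlib's default movie writer (ffmpeg), so ffmpeg must be installed for the cells below to render; anim.to_jshtml() is a dependency-free alternative if it is not available.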
In [16]:
model = SimpleModel(p=0.1, decay=1e-6, non_linearity=torch.nn.Sigmoid)
animate_obj = AnimateTraining(model, l2=0.01, title="Simple")
In [17]:
animate_obj.train(150, interval=150)
Out[17]: [HTML5 video: active-learning training animation]
In [18]:
for i, (non_linearity, title) in enumerate([
        (torch.nn.Sigmoid, "Sigmoid"),
        (torch.nn.ReLU, "ReLU"),
        (torch.nn.Tanh, "Tanh"),
        (torch.nn.Softsign, "Softsign"),
        (torch.nn.Softshrink, "Softshrink"),
        (torch.nn.Softplus, "Softplus")
        ]):
    display(HTML("<h1>{}</h1>".format(title)))
    model = SimpleModel(p=0.1, decay=1e-6, non_linearity=non_linearity)
    animate_obj = AnimateTraining(model, l2=0.01, title=title)
    display(animate_obj.train(150, interval=150))