In [1]:
%matplotlib inline

import numpy as np
import torch
from torch.autograd import Variable

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm, trange
from ipywidgets import interact, fixed

In [2]:
sns.set_context("poster")
sns.set_style("ticks")

np.random.seed(1337)

In [3]:
def get_data(N, min_x, max_x):
    w, b = np.random.rand(2)
    def true_model(X):
        y = np.sin(np.exp(w * X + b))
        return y
    X_true = np.linspace(min_x, max_x, N)
    y_true = true_model(X_true)
    span = (max_x - min_x)
    scale = 0.25
    X_obs = min_x + span*scale + np.random.rand(N)*(span - 2*scale*span)
    y_obs = true_model(X_obs) + np.random.randn(N)*0.2
    
    X_mean, X_std = X_obs.mean(), X_obs.std()
    y_mean, y_std = y_obs.mean(), y_obs.std()
    
    X_obs = (X_obs - X_mean)/ X_std
    X_true = (X_true - X_mean)/X_std
    
    y_obs = (y_obs - y_mean)/y_std
    y_true = (y_true - y_mean)/y_std
    
    return (X_obs, y_obs, X_true, y_true), (w, b, true_model)

In [4]:
N = 100
min_x, max_x = -20, 20
(X_obs, y_obs, X_true, y_true), (w, b, true_model) = get_data(N, min_x, max_x)

In [5]:
plt.plot(X_obs, y_obs, ls="none", marker="o", color="k", label="observed")
#plt.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
plt.legend()
sns.despine(offset=10)


/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [6]:
plt.plot(X_obs, y_obs, ls="none", marker="o", color="k", label="observed")
plt.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
plt.legend()
sns.despine(offset=10)


/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [7]:
class SimpleModel(torch.nn.Module):
    def __init__(self, p=0.05, decay=0.001, non_linearity=torch.nn.ReLU):
        super(SimpleModel, self).__init__()
        self.dropout_p = p
        self.decay = decay
        self.f = torch.nn.Sequential(
            torch.nn.Linear(1,20),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20, 20),
            non_linearity(),
            torch.nn.Dropout(p=self.dropout_p),
            torch.nn.Linear(20,1)
        )
    def forward(self, X):
        X = Variable(torch.Tensor(X), requires_grad=False)
        return self.f(X)

In [8]:
model = SimpleModel(p=0.1, decay = 1e-6)

In [9]:
def uncertainity_estimate(X, model, iters, l2=0.005):
    outputs = np.hstack([model(X[:, np.newaxis]).data.numpy() for i in trange(iters)])
    y_mean = outputs.mean(axis=1)
    y_variance = outputs.var(axis=1)
    tau = l2 * (1-model.dropout_p) / (2*N*model.decay)
    y_variance += (1/tau)
    y_std = np.sqrt(y_variance)
    return y_mean, y_std

In [10]:
y_mean, y_std = uncertainity_estimate(X_true, model, 200)


100%|██████████| 200/200 [00:00<00:00, 2596.00it/s]

In [11]:
def plot_model(model, iters=200, l2=1, n_std=3, ax=None):
    if ax is None:
        plt.close("all")
        plt.clf()
        fig, ax = plt.subplots(1,1)
    y_mean, y_std = uncertainity_estimate(X_true, model, iters, l2=l2)
    
    ax.plot(X_obs, y_obs, ls="none", marker="o", color="0.1", alpha=0.5, label="observed")
    ax.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
    ax.plot(X_true, y_mean, ls="-", color="b", label="mean")
    for i in range(n_std):
        ax.fill_between(
            X_true,
            y_mean - y_std * ((i+1)/2),
            y_mean + y_std * ((i+1)/2),
            color="b",
            alpha=0.1
        )
    ax.legend()
    sns.despine(offset=10)
    return ax

In [12]:
plot_model(model, n_std=3, l2=0.01)


100%|██████████| 200/200 [00:00<00:00, 2711.03it/s]
Out[12]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f9c29dd0b38>
<matplotlib.figure.Figure at 0x7f9c29e58748>
/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [13]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.001, momentum=0.9,
    weight_decay=model.decay)

In [14]:
def fit_model(model, optimizer):
    y = Variable(torch.Tensor(y_obs[:, np.newaxis]), requires_grad=False)
    y_pred = model(X_obs[:, np.newaxis])
    optimizer.zero_grad()
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss

In [15]:
fig = plt.figure(figsize=(10, 15))
ax0 = plt.subplot2grid((3,1), (0, 0), rowspan=2)
ax1 = plt.subplot2grid((3,1), (2, 0))
losses = []
for i in trange(10000):
    loss = fit_model(model, optimizer)
    losses.append(loss.data.numpy()[0])
print("loss={}".format(loss))
ax1.plot(losses, ls="-", lw=1, alpha=0.5)
plot_model(model, ax=ax0, l2=1)


100%|██████████| 10000/10000 [00:10<00:00, 932.08it/s]
100%|██████████| 200/200 [00:00<00:00, 2943.04it/s]
loss=Variable containing:
 0.2617
[torch.FloatTensor of size 1]


Out[15]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f9c29cc66d8>
/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [16]:
fig, ax = plt.subplots(1,1)
model.eval()
y_pred = model(X_true[:, np.newaxis]).data.numpy()
ax.plot(X_obs, y_obs, ls="none", marker="o", color="0.1", alpha=0.5, label="observed")
ax.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
ax.plot(X_true, y_pred, ls="-", color="b", label="predicted")
ax.legend()
sns.despine(offset=10)


/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [17]:
fig, ax = plt.subplots(1,1)
model.eval()
y_pred = model(X_true[:, np.newaxis]).data.numpy()
ax.plot(X_obs, y_obs, ls="none", marker="o", color="0.1", alpha=0.5, label="observed")
#ax.plot(X_true, y_true, ls="-", color="r", label="true", alpha=0.5)
ax.plot(X_true, y_pred, ls="-", color="b", label="predicted")
ax.legend()
sns.despine(offset=10)


/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

Compare non-linearities


In [18]:
fig, ax = plt.subplots(2,3, sharex=True, sharey=False, figsize=(24,16))

ax = ax.flatten()
for i, (non_linearity, title) in enumerate([
    (torch.nn.Sigmoid, "Sigmoid"),
    (torch.nn.ReLU, "ReLU"),
    (torch.nn.Tanh, "Tanh"),
    (torch.nn.Softsign, "Softsign"),
    (torch.nn.Softshrink, "Softshrink"),
    (torch.nn.Softplus, "Softplus")
]):
    model = SimpleModel(p=0.1, decay = 1e-6, non_linearity=non_linearity)
    criterion = torch.nn.MSELoss()
    """optimizer = torch.optim.SGD(
        model.parameters(), lr=0.001, momentum=0.9,
        weight_decay=model.decay)"""
    optimizer = torch.optim.Adam(
        model.parameters(), #lr=0.001, momentum=0.9,
        weight_decay=model.decay)
    losses = []
    for epoch in trange(10000):
        loss = fit_model(model, optimizer)
        losses.append(loss.data.numpy()[0])
    loss = loss.data.numpy()[0]
    print("{}: loss={}".format(title, loss))
    plot_model(model, ax=ax[i], l2=1)
    ax[i].set_title("{}, Loss: {:.3f}".format(title, loss))

sns.despine(offset=10)
fig.tight_layout()


100%|██████████| 10000/10000 [00:11<00:00, 853.68it/s]
100%|██████████| 200/200 [00:00<00:00, 2972.33it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]
Sigmoid: loss=0.24633952975273132
100%|██████████| 10000/10000 [00:11<00:00, 879.96it/s]
100%|██████████| 200/200 [00:00<00:00, 2973.27it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]
ReLU: loss=0.23126958310604095
100%|██████████| 10000/10000 [00:11<00:00, 843.58it/s]
100%|██████████| 200/200 [00:00<00:00, 2697.24it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]
Tanh: loss=0.2576611042022705
100%|██████████| 10000/10000 [00:11<00:00, 892.36it/s]
100%|██████████| 200/200 [00:00<00:00, 3062.20it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]
Softsign: loss=0.27574217319488525
100%|██████████| 10000/10000 [00:11<00:00, 872.45it/s]
100%|██████████| 200/200 [00:00<00:00, 3020.76it/s]
  0%|          | 0/10000 [00:00<?, ?it/s]
Softshrink: loss=0.27915409207344055
100%|██████████| 10000/10000 [00:12<00:00, 788.27it/s]
100%|██████████| 200/200 [00:00<00:00, 2425.06it/s]
Softplus: loss=0.2493857890367508
/home/napsternxg/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

In [ ]: