Scalable Additive-Structure GP Classification (CUDA) (w/ KISS-GP)

Introduction

This example shows how to build a GP classification model using an AdditiveGridInterpolationVariationalStrategy. This setup is designed for when the function you're modeling decomposes additively over its input dimensions, which is equivalent to using a covariance function that additively decomposes over dimensions:

$$k(\mathbf{x},\mathbf{x'}) = \sum_{i=1}^{d}k([\mathbf{x}]_{i}, [\mathbf{x'}]_{i})$$

where $[\mathbf{x}]_{i}$ denotes the $i$th component of the vector $\mathbf{x}$. Example applications of this structure include Bayesian optimization and deep kernel learning.
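
To make the decomposition concrete, here is a minimal plain-PyTorch sketch (not the GPyTorch API used below) of evaluating such an additive RBF covariance by summing one-dimensional kernels:

import torch

def rbf_1d(a, b, lengthscale=1.0):
    # One-dimensional RBF kernel between vectors a (n,) and b (m,) -> (n, m) matrix
    diff = a.unsqueeze(-1) - b.unsqueeze(-2)
    return torch.exp(-0.5 * (diff / lengthscale) ** 2)

def additive_rbf(x1, x2, lengthscale=1.0):
    # Sum of per-dimension RBF kernels for inputs x1 (n, d) and x2 (m, d)
    return sum(rbf_1d(x1[:, i], x2[:, i], lengthscale) for i in range(x1.size(-1)))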

The use of grid inducing points (KISS-GP) lets us scale to much larger training sets, since the computational complexity becomes linear, rather than cubic, in the number of data points.
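
For reference, this is the standard KISS-GP / SKI idea: the covariance matrix over the training inputs is approximated by sparse interpolation against a covariance matrix evaluated on a regular grid of inducing points $Z$,

$$K_{XX} \approx W K_{ZZ} W^{\top},$$

where $W$ is a sparse interpolation matrix. Because $Z$ is a regular one-dimensional grid for each additive component, matrix-vector multiplies with this approximation cost roughly $O(n + m \log m)$ for $m$ grid points, rather than the $O(n^3)$ of exact GP inference.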

In this example, we’re performing classification on a two dimensional toy dataset that is:

  • Defined in [-1, 1]x[-1, 1]
  • Valued 1 in [-0.5, 0.5]x[-0.5, 0.5]
  • Valued -1 otherwise

The above function doesn't have an obvious additive decomposition, but it turns out that it can still be modeled very well by an additive kernel: as the worked example below shows, the decision boundary (though not the ±1 labels themselves) can be written exactly in terms of an additive latent function.
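
As a small worked example, take the additive latent function

$$f(x_1, x_2) = g(x_1) + g(x_2) - \tfrac{3}{2}, \qquad g(t) = \begin{cases} 1 & |t| < 0.5 \\ 0 & \text{otherwise.} \end{cases}$$

Then $f > 0$ exactly on $[-0.5, 0.5] \times [-0.5, 0.5]$ (where $g(x_1) + g(x_2) = 2$) and $f < 0$ everywhere else, so thresholding the Bernoulli class probability at 0.5 recovers the target labels even though the additive latent function never takes the values $\pm 1$.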


In [1]:
# High-level imports
import math
from math import exp
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make inline plots
%matplotlib inline



Generate toy dataset


In [2]:
# Build a 101 x 101 regular grid of training inputs on [-1, 1] x [-1, 1]
n = 101
train_x = torch.zeros(n ** 2, 2)
train_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))
train_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))
# Labels are 1 inside [-0.5, 0.5] x [-0.5, 0.5] and -1 everywhere else
train_y = (train_x[:, 0].abs().lt(0.5)).float() * (train_x[:, 1].abs().lt(0.5)).float() * 2 - 1

train_x = train_x.cuda()
train_y = train_y.cuda()

Define the model

In contrast to the most basic classification models, this model uses an AdditiveGridInterpolationVariationalStrategy. This causes two key changes in the model. First, the model now specifically assumes that the input to forward, x, is to be additively decomposed. Thus, although the model below defines an RBFKernel as the covariance function, this variational strategy imposes the additive decomposition over dimensions discussed above.

Second, this model automatically uses scalable kernel interpolation (SKI) along each dimension. Because of the additive decomposition, we only provide one set of grid bounds to the variational strategy, as the same one-dimensional grid is reused for every dimension. It is recommended that you scale your training and test data so that they fall within these grid bounds (here, [-1, 1]).


In [3]:
from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import AdditiveGridInterpolationVariationalStrategy, CholeskyVariationalDistribution
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import BernoulliLikelihood
from gpytorch.means import ConstantMean
from gpytorch.distributions import MultivariateNormal

class GPClassificationModel(AbstractVariationalGP):
    def __init__(self, grid_size=128, grid_bounds=([-1, 1],)):
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=grid_size, batch_size=2)
        variational_strategy = AdditiveGridInterpolationVariationalStrategy(self,
                                                                            grid_size=grid_size,
                                                                            grid_bounds=grid_bounds,
                                                                            num_dim=2,
                                                                            variational_distribution=variational_distribution)
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=1))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = MultivariateNormal(mean_x, covar_x)
        return latent_pred

# Cuda the model and likelihood function
model = GPClassificationModel().cuda()
likelihood = gpytorch.likelihoods.BernoulliLikelihood().cuda()

Training the model

Once the model has been defined, the training loop looks very similar to those of the other variational models we've seen. We optimize the variational lower bound (ELBO) as our objective function. Although variational inference in GPyTorch supports stochastic (minibatch) gradient descent, we use full-batch optimization here because the toy dataset is relatively small.
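
Concretely, the objective is the standard variational evidence lower bound (ELBO) for sparse variational GPs,

$$\mathcal{L} = \sum_{i=1}^{N} \mathbb{E}_{q(f_i)}\!\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\!\left[\,q(\mathbf{u})\,\|\,p(\mathbf{u})\,\right],$$

where $q(\mathbf{u})$ is the variational distribution over the inducing values (the CholeskyVariationalDistribution above) and the expected log-likelihood sums over all $N$ training points; this is why VariationalELBO takes num_data below.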

For an example of using this kind of additive grid-interpolation variational model with stochastic gradient descent, see the dkl_mnist example.
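
For reference, a rough minibatch sketch (not part of the original notebook; it reuses the model, likelihood, and training tensors defined above and introduces its own optimizer and ELBO objects) could look like this:

from torch.utils.data import TensorDataset, DataLoader

# Hypothetical minibatch training loop
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=1024, shuffle=True)
sgd_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# num_data must still be the full training set size so the ELBO is scaled correctly
sgd_mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

for epoch in range(20):
    for x_batch, y_batch in train_loader:
        sgd_optimizer.zero_grad()
        loss = -sgd_mll(model(x_batch), y_batch)
        loss.backward()
        sgd_optimizer.step()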


In [4]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# "Loss" for GPs - the marginal log likelihood
# num_data refers to the number of training data points
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

# Training function
def train(num_iter=200):
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))
        optimizer.step()

%time train()


Iter 1/200 - Loss: 1.421
Iter 2/200 - Loss: 1.415
Iter 3/200 - Loss: 0.863
Iter 4/200 - Loss: 0.908
Iter 5/200 - Loss: 1.154
Iter 6/200 - Loss: 1.289
Iter 7/200 - Loss: 1.069
Iter 8/200 - Loss: 1.341
Iter 9/200 - Loss: 1.304
Iter 10/200 - Loss: 1.014
Iter 11/200 - Loss: 1.087
Iter 12/200 - Loss: 1.293
Iter 13/200 - Loss: 1.176
Iter 14/200 - Loss: 1.166
Iter 15/200 - Loss: 1.245
Iter 16/200 - Loss: 1.250
Iter 17/200 - Loss: 1.023
Iter 18/200 - Loss: 0.926
Iter 19/200 - Loss: 1.113
Iter 20/200 - Loss: 1.115
Iter 21/200 - Loss: 0.861
Iter 22/200 - Loss: 0.720
Iter 23/200 - Loss: 0.677
Iter 24/200 - Loss: 0.827
Iter 25/200 - Loss: 0.709
Iter 26/200 - Loss: 0.697
Iter 27/200 - Loss: 0.928
Iter 28/200 - Loss: 0.917
Iter 29/200 - Loss: 0.968
Iter 30/200 - Loss: 1.099
Iter 31/200 - Loss: 1.112
Iter 32/200 - Loss: 0.924
Iter 33/200 - Loss: 0.622
Iter 34/200 - Loss: 0.760
Iter 35/200 - Loss: 0.645
Iter 36/200 - Loss: 1.126
Iter 37/200 - Loss: 0.654
Iter 38/200 - Loss: 0.606
Iter 39/200 - Loss: 0.632
Iter 40/200 - Loss: 0.706
Iter 41/200 - Loss: 0.568
Iter 42/200 - Loss: 0.610
Iter 43/200 - Loss: 0.573
Iter 44/200 - Loss: 0.592
Iter 45/200 - Loss: 0.626
Iter 46/200 - Loss: 0.614
Iter 47/200 - Loss: 0.729
Iter 48/200 - Loss: 0.602
Iter 49/200 - Loss: 0.640
Iter 50/200 - Loss: 0.649
Iter 51/200 - Loss: 0.614
Iter 52/200 - Loss: 0.491
Iter 53/200 - Loss: 0.559
Iter 54/200 - Loss: 0.482
Iter 55/200 - Loss: 0.491
Iter 56/200 - Loss: 0.576
Iter 57/200 - Loss: 0.500
Iter 58/200 - Loss: 0.502
Iter 59/200 - Loss: 0.577
Iter 60/200 - Loss: 0.508
Iter 61/200 - Loss: 0.538
Iter 62/200 - Loss: 0.486
Iter 63/200 - Loss: 0.503
Iter 64/200 - Loss: 0.565
Iter 65/200 - Loss: 0.491
Iter 66/200 - Loss: 0.451
Iter 67/200 - Loss: 0.406
Iter 68/200 - Loss: 0.513
Iter 69/200 - Loss: 0.461
Iter 70/200 - Loss: 0.570
Iter 71/200 - Loss: 0.460
Iter 72/200 - Loss: 0.430
Iter 73/200 - Loss: 0.502
Iter 74/200 - Loss: 0.475
Iter 75/200 - Loss: 0.397
Iter 76/200 - Loss: 0.552
Iter 77/200 - Loss: 0.458
Iter 78/200 - Loss: 0.390
Iter 79/200 - Loss: 0.429
Iter 80/200 - Loss: 0.384
Iter 81/200 - Loss: 0.409
Iter 82/200 - Loss: 0.377
Iter 83/200 - Loss: 0.389
Iter 84/200 - Loss: 0.460
Iter 85/200 - Loss: 0.430
Iter 86/200 - Loss: 0.385
Iter 87/200 - Loss: 0.345
Iter 88/200 - Loss: 0.422
Iter 89/200 - Loss: 0.384
Iter 90/200 - Loss: 0.408
Iter 91/200 - Loss: 0.328
Iter 92/200 - Loss: 0.367
Iter 93/200 - Loss: 0.374
Iter 94/200 - Loss: 0.378
Iter 95/200 - Loss: 0.401
Iter 96/200 - Loss: 0.379
Iter 97/200 - Loss: 0.398
Iter 98/200 - Loss: 0.350
Iter 99/200 - Loss: 0.363
Iter 100/200 - Loss: 0.348
Iter 101/200 - Loss: 0.380
Iter 102/200 - Loss: 0.355
Iter 103/200 - Loss: 0.354
Iter 104/200 - Loss: 0.362
Iter 105/200 - Loss: 0.334
Iter 106/200 - Loss: 0.346
Iter 107/200 - Loss: 0.362
Iter 108/200 - Loss: 0.325
Iter 109/200 - Loss: 0.343
Iter 110/200 - Loss: 0.378
Iter 111/200 - Loss: 0.339
Iter 112/200 - Loss: 0.325
Iter 113/200 - Loss: 0.325
Iter 114/200 - Loss: 0.324
Iter 115/200 - Loss: 0.317
Iter 116/200 - Loss: 0.335
Iter 117/200 - Loss: 0.340
Iter 118/200 - Loss: 0.340
Iter 119/200 - Loss: 0.302
Iter 120/200 - Loss: 0.307
Iter 121/200 - Loss: 0.300
Iter 122/200 - Loss: 0.308
Iter 123/200 - Loss: 0.301
Iter 124/200 - Loss: 0.315
Iter 125/200 - Loss: 0.302
Iter 126/200 - Loss: 0.297
Iter 127/200 - Loss: 0.305
Iter 128/200 - Loss: 0.296
Iter 129/200 - Loss: 0.291
Iter 130/200 - Loss: 0.310
Iter 131/200 - Loss: 0.303
Iter 132/200 - Loss: 0.297
Iter 133/200 - Loss: 0.297
Iter 134/200 - Loss: 0.275
Iter 135/200 - Loss: 0.293
Iter 136/200 - Loss: 0.279
Iter 137/200 - Loss: 0.313
Iter 138/200 - Loss: 0.262
Iter 139/200 - Loss: 0.284
Iter 140/200 - Loss: 0.261
Iter 141/200 - Loss: 0.256
Iter 142/200 - Loss: 0.260
Iter 143/200 - Loss: 0.286
Iter 144/200 - Loss: 0.260
Iter 145/200 - Loss: 0.290
Iter 146/200 - Loss: 0.279
Iter 147/200 - Loss: 0.262
Iter 148/200 - Loss: 0.254
Iter 149/200 - Loss: 0.263
Iter 150/200 - Loss: 0.270
Iter 151/200 - Loss: 0.283
Iter 152/200 - Loss: 0.257
Iter 153/200 - Loss: 0.269
Iter 154/200 - Loss: 0.250
Iter 155/200 - Loss: 0.250
Iter 156/200 - Loss: 0.259
Iter 157/200 - Loss: 0.265
Iter 158/200 - Loss: 0.245
Iter 159/200 - Loss: 0.259
Iter 160/200 - Loss: 0.262
Iter 161/200 - Loss: 0.262
Iter 162/200 - Loss: 0.260
Iter 163/200 - Loss: 0.254
Iter 164/200 - Loss: 0.240
Iter 165/200 - Loss: 0.230
Iter 166/200 - Loss: 0.260
Iter 167/200 - Loss: 0.243
Iter 168/200 - Loss: 0.255
Iter 169/200 - Loss: 0.232
Iter 170/200 - Loss: 0.245
Iter 171/200 - Loss: 0.232
Iter 172/200 - Loss: 0.243
Iter 173/200 - Loss: 0.239
Iter 174/200 - Loss: 0.257
Iter 175/200 - Loss: 0.243
Iter 176/200 - Loss: 0.234
Iter 177/200 - Loss: 0.234
Iter 178/200 - Loss: 0.240
Iter 179/200 - Loss: 0.232
Iter 180/200 - Loss: 0.230
Iter 181/200 - Loss: 0.235
Iter 182/200 - Loss: 0.215
Iter 183/200 - Loss: 0.227
Iter 184/200 - Loss: 0.227
Iter 185/200 - Loss: 0.226
Iter 186/200 - Loss: 0.231
Iter 187/200 - Loss: 0.231
Iter 188/200 - Loss: 0.210
Iter 189/200 - Loss: 0.226
Iter 190/200 - Loss: 0.216
Iter 191/200 - Loss: 0.218
Iter 192/200 - Loss: 0.219
Iter 193/200 - Loss: 0.225
Iter 194/200 - Loss: 0.213
Iter 195/200 - Loss: 0.203
Iter 196/200 - Loss: 0.223
Iter 197/200 - Loss: 0.210
Iter 198/200 - Loss: 0.211
Iter 199/200 - Loss: 0.220
Iter 200/200 - Loss: 0.201
CPU times: user 16.5 s, sys: 1.18 s, total: 17.7 s
Wall time: 17.6 s

Test the model

Next we test the model and plot the decision boundary. Despite the target function not having an obvious additive decomposition, the model recovers the decision boundary accurately.


In [5]:
# Switch the model and likelihood into the evaluation mode
model.eval()
likelihood.eval()

# Start the plot, 4x3in
f, ax = plt.subplots(1, 1, figsize=(4, 3))

n = 150
test_x = torch.zeros(n ** 2, 2)
test_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))
test_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))
# Cuda variable of test data
test_x = test_x.cuda()

with torch.no_grad():
    predictions = likelihood(model(test_x))

# prob<0.5 --> label -1 // prob>0.5 --> label 1
pred_labels = predictions.mean.ge(0.5).float().mul(2).sub(1).cpu()
# Colors: yellow for label 1, red for label -1
color = ['y' if label == 1 else 'r' for label in pred_labels]

# Plot the data as a scatter plot
ax.scatter(test_x[:, 0].cpu(), test_x[:, 1].cpu(), color=color, s=1)


Out[5]:
<matplotlib.collections.PathCollection at 0x7fc556221240>
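
As an additional sanity check (not in the original notebook), one could compare the predicted labels against the ground-truth labels on the same test grid, reusing test_x and pred_labels from the cell above:

# Ground-truth labels on the test grid, defined the same way as the training labels
test_y = (test_x[:, 0].abs().lt(0.5)).float() * (test_x[:, 1].abs().lt(0.5)).float() * 2 - 1
# Fraction of test points whose predicted label matches the true label
accuracy = pred_labels.eq(test_y.cpu()).float().mean().item()
print('Test accuracy: %.3f' % accuracy)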
