Deep Kernel Learning GP Regression (w/ KISS-GP)

Overview

In this notebook, we'll give a brief tutorial on how to use deep kernel learning for regression on a medium-scale dataset using SKI. This also demonstrates how to incorporate standard PyTorch modules into a Gaussian process model.


In [1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

Loading Data

For this example notebook, we'll be using the elevators UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. We'll simply split the data, using the first 80% as training and the last 20% as testing.

Note: Running the next cell will attempt to download a ~400 KB dataset file to the current directory.


In [2]:
import urllib.request
import os.path
from scipy.io import loadmat
from math import floor

if not os.path.isfile('elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', 'elevators.mat')
    
data = torch.Tensor(loadmat('elevators.mat')['data'])
X = data[:, :-1]
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1
y = data[:, -1]

# Use the first 80% of the data for training, and the last 20% for testing.
train_n = int(floor(0.8*len(X)))

train_x = X[:train_n, :].contiguous().cuda()
train_y = y[:train_n].contiguous().cuda()

test_x = X[train_n:, :].contiguous().cuda()
test_y = y[train_n:].contiguous().cuda()
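
The cell above hard-codes .cuda() and therefore assumes a CUDA-capable GPU. As a hedged aside (not part of the original notebook), you can make the notebook device-agnostic by selecting the device at runtime; the sketch below introduces a device variable that is not used in the rest of this notebook.

# Optional sketch: pick the device dynamically instead of hard-coding .cuda()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_x = X[:train_n, :].contiguous().to(device)
train_y = y[:train_n].contiguous().to(device)
test_x = X[train_n:, :].contiguous().to(device)
test_y = y[train_n:].contiguous().to(device)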

Defining the DKL Feature Extractor

Next, we define the neural network feature extractor used to define the deep kernel. In this case, we use a fully connected network with the architecture d -> 1000 -> 500 -> 50 -> 2, as described in the original DKL paper. All of the code below uses standard PyTorch implementations of neural network layers.


In [3]:
data_dim = train_x.size(-1)

class LargeFeatureExtractor(torch.nn.Sequential):           
    def __init__(self):                                      
        super(LargeFeatureExtractor, self).__init__()        
        self.add_module('linear1', torch.nn.Linear(data_dim, 1000))
        self.add_module('relu1', torch.nn.ReLU())                  
        self.add_module('linear2', torch.nn.Linear(1000, 500))     
        self.add_module('relu2', torch.nn.ReLU())                  
        self.add_module('linear3', torch.nn.Linear(500, 50))       
        self.add_module('relu3', torch.nn.ReLU())                  
        self.add_module('linear4', torch.nn.Linear(50, 2))         
                                                             
feature_extractor = LargeFeatureExtractor().cuda()
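
As a quick sanity check (an addition, not part of the original notebook), you can push a few training points through the extractor and confirm that each input is mapped down to the expected two features:

# Verify that the extractor maps d-dimensional inputs to 2 output features
with torch.no_grad():
    sample_features = feature_extractor(train_x[:5])
print(sample_features.shape)  # expected: torch.Size([5, 2])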

Defining the GP Model

We now define the GP model. For more details on the use of GP models, see our simpler examples. This model uses a GridInterpolationKernel (SKI) with an RBF base kernel.

The forward method

In deep kernel learning, the forward method is where most of the interesting new stuff happens. Before calling the mean and covariance modules on the data as in the simple GP regression setting, we first pass the input data x through the neural network feature extractor. Then, to ensure that the output features of the neural network remain within the grid bounds expected by SKI, we scale the resulting features to lie between -1 and 1.

Only after this processing do we call the mean and covariance module of the Gaussian process. This example also demonstrates the flexibility of defining GP models that allow for learned transformations of the data (in this case, via a neural network) before calling the mean and covariance function. Because the neural network maps down to just two output features, and the size of the SKI grid grows exponentially with the number of dimensions, we will have no problem using SKI here.


In [4]:
class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=2)),
            num_dims=2, grid_size=100
        )
        self.feature_extractor = feature_extractor

    def forward(self, x):
        # We're first putting our data through a deep net (feature extractor)
        # We're also scaling the features so that they're nice values
        projected_x = self.feature_extractor(x)
        projected_x = projected_x - projected_x.min(0)[0]
        projected_x = 2 * (projected_x / projected_x.max(0)[0]) - 1

        mean_x = self.mean_module(projected_x)
        covar_x = self.covar_module(projected_x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
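
The two scaling lines in forward are simply a per-batch min-max rescaling of the learned features to [-1, 1] so they stay inside the SKI grid. As a hedged aside, recent GPyTorch releases provide a ScaleToBounds helper that performs the same rescaling; whether gpytorch.utils.grid.ScaleToBounds is available depends on your installed version, so treat the snippet below as a sketch under that assumption rather than as part of this notebook.

# Sketch only: assumes your GPyTorch version provides gpytorch.utils.grid.ScaleToBounds
scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1., 1.)

features = feature_extractor(train_x[:5])
scaled_features = scale_to_bounds(features)  # same effect as the manual rescaling in forward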

In [35]:
likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = GPRegressionModel(train_x, train_y, likelihood).cuda()

Training the model

The cell below trains the DKL model above, learning both the hyperparameters of the Gaussian process and the parameters of the neural network in an end-to-end fashion using Type-II MLE. We run 60 iterations of training using the Adam optimizer built into PyTorch. With a decent GPU, this should only take a few seconds.


In [36]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.feature_extractor.parameters()},
    {'params': model.covar_module.parameters()},
    {'params': model.mean_module.parameters()},
    {'params': model.likelihood.parameters()},
], lr=0.01)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iterations = 60
def train():
    for i in range(training_iterations):
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x)
        # Calc loss and backprop derivatives
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
        optimizer.step()
        
# See dkl_mnist.ipynb for explanation of this flag
with gpytorch.settings.use_toeplitz(True):
    %time train()


/home/jrg365/gpytorch/gpytorch/utils/cholesky.py:14: UserWarning: torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next release. Please use torch.cholesky instead and note that the :attr:`upper` argument in torch.cholesky defaults to ``False``.
  potrf_list = [sub_mat.potrf() for sub_mat in mat.view(-1, *mat.shape[-2:])]
/home/jrg365/gpytorch/gpytorch/lazy/added_diag_lazy_tensor.py:66: UserWarning: torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next release. Please use torch.cholesky instead and note that the :attr:`upper` argument in torch.cholesky defaults to ``False``.
  ld_one = lr_flipped.potrf().diag().log().sum() * 2
Iter 1/60 - Loss: 0.926
Iter 2/60 - Loss: 0.922
Iter 3/60 - Loss: 0.918
Iter 4/60 - Loss: 0.913
Iter 5/60 - Loss: 0.908
Iter 6/60 - Loss: 0.906
Iter 7/60 - Loss: 0.902
Iter 8/60 - Loss: 0.895
Iter 9/60 - Loss: 0.890
Iter 10/60 - Loss: 0.885
Iter 11/60 - Loss: 0.880
Iter 12/60 - Loss: 0.875
Iter 13/60 - Loss: 0.871
Iter 14/60 - Loss: 0.866
Iter 15/60 - Loss: 0.861
Iter 16/60 - Loss: 0.856
Iter 17/60 - Loss: 0.851
Iter 18/60 - Loss: 0.846
Iter 19/60 - Loss: 0.841
Iter 20/60 - Loss: 0.836
Iter 21/60 - Loss: 0.832
Iter 22/60 - Loss: 0.827
Iter 23/60 - Loss: 0.821
Iter 24/60 - Loss: 0.816
Iter 25/60 - Loss: 0.811
Iter 26/60 - Loss: 0.806
Iter 27/60 - Loss: 0.800
Iter 28/60 - Loss: 0.795
Iter 29/60 - Loss: 0.790
Iter 30/60 - Loss: 0.785
Iter 31/60 - Loss: 0.780
Iter 32/60 - Loss: 0.774
Iter 33/60 - Loss: 0.769
Iter 34/60 - Loss: 0.764
Iter 35/60 - Loss: 0.759
Iter 36/60 - Loss: 0.754
Iter 37/60 - Loss: 0.750
Iter 38/60 - Loss: 0.745
Iter 39/60 - Loss: 0.739
Iter 40/60 - Loss: 0.734
Iter 41/60 - Loss: 0.729
Iter 42/60 - Loss: 0.724
Iter 43/60 - Loss: 0.720
Iter 44/60 - Loss: 0.715
Iter 45/60 - Loss: 0.710
Iter 46/60 - Loss: 0.705
Iter 47/60 - Loss: 0.700
Iter 48/60 - Loss: 0.695
Iter 49/60 - Loss: 0.690
Iter 50/60 - Loss: 0.685
Iter 51/60 - Loss: 0.680
Iter 52/60 - Loss: 0.675
Iter 53/60 - Loss: 0.670
Iter 54/60 - Loss: 0.666
Iter 55/60 - Loss: 0.661
Iter 56/60 - Loss: 0.656
Iter 57/60 - Loss: 0.651
Iter 58/60 - Loss: 0.645
Iter 59/60 - Loss: 0.640
Iter 60/60 - Loss: 0.635
CPU times: user 11.3 s, sys: 3.11 s, total: 14.5 s
Wall time: 14.4 s
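
Because the optimizer above is built from explicit parameter groups, it is straightforward to give the neural network a different learning rate than the GP hyperparameters, which can help when the feature extractor dominates the parameter count. A minimal sketch of that variant (the 0.001 value is illustrative, not taken from this notebook):

# Sketch: a smaller learning rate for the deep net than for the GP hyperparameters
optimizer = torch.optim.Adam([
    {'params': model.feature_extractor.parameters(), 'lr': 0.001},
    {'params': model.covar_module.parameters()},
    {'params': model.mean_module.parameters()},
    {'params': model.likelihood.parameters()},
], lr=0.01)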

Making Predictions

The next cell gets the predictive covariance for the test set (and also gets the predictive mean, stored in preds.mean) using the standard SKI prediction code, with fast predictive variances (LOVE) enabled via gpytorch.settings.fast_pred_var() and Toeplitz acceleration turned off.


In [37]:
model.eval()
likelihood.eval()
with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
    preds = model(test_x)


/home/jrg365/gpytorch/gpytorch/utils/cholesky.py:14: UserWarning: torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next release. Please use torch.cholesky instead and note that the :attr:`upper` argument in torch.cholesky defaults to ``False``.
  potrf_list = [sub_mat.potrf() for sub_mat in mat.view(-1, *mat.shape[-2:])]
/home/jrg365/gpytorch/gpytorch/lazy/added_diag_lazy_tensor.py:66: UserWarning: torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next release. Please use torch.cholesky instead and note that the :attr:`upper` argument in torch.cholesky defaults to ``False``.
  ld_one = lr_flipped.potrf().diag().log().sum() * 2

In [38]:
print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))


Test MAE: 0.07841506600379944
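
If you also want predictions that account for observation noise, you can pass the model output through the likelihood. The sketch below (an addition, not part of the original notebook) does this and reports RMSE alongside the MAE above:

# Sketch: predictive distribution including the likelihood's observation noise
with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))

rmse = torch.sqrt(torch.mean((observed_pred.mean - test_y) ** 2))
print('Test RMSE: {}'.format(rmse.item()))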
