Scalable GP Classification in 1D (w/ KISS-GP)

This example shows how to use grid-interpolation-based variational classification with an AbstractVariationalGP, using a GridInterpolationVariationalStrategy module. This strategy is designed for cases where the inputs of the function you're modeling are one-dimensional.

Placing the inducing points on a regular grid and interpolating against them lets the method scale to large training sets: the computational complexity becomes linear in the number of data points rather than cubic.
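
Concretely (a sketch of the approximation from the KISS-GP paper below), the n × n train/train covariance matrix is replaced by interpolation against the covariance evaluated on m grid-structured inducing points:

$$
K_{XX} \approx W K_{UU} W^\top
$$

where $W$ is a sparse interpolation matrix with a fixed number of nonzero entries per row. For a stationary kernel on a regular grid, matrix-vector multiplies with this approximation cost $O(n + m \log m)$, which is what makes inference linear in $n$.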

In this example, we're modeling a function with periodic binary labels, sign(cos(2πx)): a square wave with period 1 over the interval [0, 1].

This notebook doesn't use CUDA. In general, we recommend using a GPU if one is available, and most of our other notebooks do use CUDA.
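
If you do have a GPU, a minimal sketch of the usual pattern (not used in this notebook) is to move the data, model, and likelihood onto the device after constructing them:

# Hypothetical CUDA variant (not used in this notebook); run after the
# data, model, and likelihood below have been constructed.
if torch.cuda.is_available():
    train_x, train_y = train_x.cuda(), train_y.cuda()
    model, likelihood = model.cuda(), likelihood.cuda()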

Kernel interpolation for scalable structured Gaussian processes (KISS-GP) was introduced in this paper: http://proceedings.mlr.press/v37/wilson15.pdf

KISS-GP with SVI for classification was introduced in this paper: https://papers.nips.cc/paper/6426-stochastic-variational-deep-kernel-learning.pdf


In [1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2


In [2]:
# Training data: 26 evenly spaced points on [0, 1], labeled by the sign
# of cos(2*pi*x) (a square wave with period 1)
train_x = torch.linspace(0, 1, 26)
train_y = torch.sign(torch.cos(train_x * (2 * math.pi)))

In [3]:
from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import GridInterpolationVariationalStrategy


class GPClassificationModel(AbstractVariationalGP):
    def __init__(self, grid_size=128, grid_bounds=[(0, 1)]):
        # The variational distribution and strategy live on a fixed grid of
        # inducing points; predictions interpolate against this grid
        variational_distribution = CholeskyVariationalDistribution(grid_size)
        variational_strategy = GridInterpolationVariationalStrategy(
            self, grid_size, grid_bounds, variational_distribution
        )
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        # Latent GP: constant mean with a scaled RBF covariance
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred


model = GPClassificationModel()
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
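
The BernoulliLikelihood squashes the latent GP through a probit link, so a latent value f maps to p(y = 1 | f) = Φ(f), the standard normal CDF. A minimal sketch of that mapping (illustrative only, using plain torch rather than the likelihood object):

# Probit link: latent f = 0 corresponds to probability 0.5
f = torch.tensor([-2.0, 0.0, 2.0])
probs = torch.distributions.Normal(0.0, 1.0).cdf(f)
print(probs)  # approximately [0.023, 0.500, 0.977]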

In [4]:
from gpytorch.mlls.variational_elbo import VariationalELBO

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# "Loss" for GPs - the marginal log likelihood
# n_data refers to the amount of training data
mll = VariationalELBO(likelihood, model, num_data=train_y.numel())

def train():
    num_iter = 400
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))
        optimizer.step()
        
# Time the training run
%time train()


Iter 1/400 - Loss: 1.416
Iter 2/400 - Loss: 6.308
Iter 3/400 - Loss: 1.661
Iter 4/400 - Loss: 2.645
Iter 5/400 - Loss: 3.797
Iter 6/400 - Loss: 3.026
Iter 7/400 - Loss: 1.811
Iter 8/400 - Loss: 1.248
Iter 9/400 - Loss: 1.638
Iter 10/400 - Loss: 2.314
...
Iter 396/400 - Loss: 0.625
Iter 397/400 - Loss: 0.638
Iter 398/400 - Loss: 0.662
Iter 399/400 - Loss: 0.648
Iter 400/400 - Loss: 0.655
CPU times: user 28.6 s, sys: 29.2 s, total: 57.7 s
Wall time: 8.28 s
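
For reference, the loss minimized above is the negative of the standard variational ELBO for sparse GP classification (sketched here in the usual SVGP form, with $q(\mathbf{u})$ the variational distribution over inducing values and $p(\mathbf{u})$ the prior):

$$
\mathcal{L} = \sum_{i=1}^{n} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left(q(\mathbf{u}) \,\|\, p(\mathbf{u})\right)
$$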

In [5]:
# Set model and likelihood into eval mode
model.eval()
likelihood.eval()

# Initialize axes
f, ax = plt.subplots(1, 1, figsize=(4, 3))

with torch.no_grad():
    test_x = torch.linspace(0, 1, 101)
    predictions = likelihood(model(test_x))

# Plot the training data as black stars
ax.plot(train_x.numpy(), train_y.numpy(), 'k*')
# predictions.mean is p(y = 1); threshold at 0.5 and map {0, 1} to {-1, +1}
pred_labels = predictions.mean.ge(0.5).float().mul(2).sub(1)
ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')
ax.set_ylim([-3, 3])
ax.legend(['Observed Data', 'Predicted Labels'])


Out[5]:
<matplotlib.legend.Legend at 0x7fd0f825eda0>
