In [1]:
import os,sys
sys.path.append(os.path.join(os.getcwd(),'..'))  # make the candlegp checkout in the parent directory importable
import candlegp
import torch
from torch.autograd import Variable  # old-style (pre-0.4) PyTorch autograd wrapper
import numpy
from matplotlib import pyplot
pyplot.style.use('ggplot')
%matplotlib inline
torch.manual_seed(376106123)  # fix the RNG so the sampled data and results are reproducible
Out[1]:
In [2]:
# Draw 100 random 1-D inputs and sample 3 latent functions from a GP prior.
# K is a squared-exponential kernel (lengthscale^2 = 0.01) plus 1e-6 jitter on
# the diagonal for numerical stability of the Cholesky factorization.
X = torch.rand(100,1,out=torch.DoubleTensor())
K = torch.exp(-0.5*(X - X.t())**2/0.01) + torch.eye(100, out=torch.DoubleTensor())*1e-6
# f = L @ eps with L the lower Cholesky factor of K, i.e. f ~ N(0, K);
# one column per latent function (3 columns -> 3 classes below).
f = torch.matmul(torch.potrf(K, upper=False), torch.randn(100,3,out=torch.DoubleTensor()))
pyplot.figure(figsize=(12,6))
pyplot.plot(X.numpy(), f.numpy(), '.')
Out[2]:
In [3]:
Y = torch.max(f,1)[1].unsqueeze(1).double()
In [4]:
# Sparse variational GP classifier: Matern-3/2 + white-noise kernel, 3-class
# multiclass likelihood, inducing inputs Z at every 5th data point,
# whitened representation with a diagonal variational covariance (q_diag).
k_w = candlegp.kernels.White(1, variance=0.01).double()
k = candlegp.kernels.Matern32(1).double() + k_w
m = candlegp.models.SVGP(
Variable(X), Variable(Y), kern=k,
likelihood=candlegp.likelihoods.MultiClass(3),
Z=Variable(X[::5].clone()), num_latent=3, whiten=True, q_diag=True)
In [5]:
# Freeze the white-noise variance and the inducing inputs during optimization.
k_w.variance.requires_grad = False
m.Z.requires_grad = False
opt = torch.optim.LBFGS([p for p in m.parameters() if p.requires_grad], lr=1e-3, max_iter=40)

def eval_model():
    """LBFGS closure: zero grads, evaluate the objective, backprop, return it."""
    opt.zero_grad()
    obj = m()
    obj.backward()
    return obj

for i in range(100):
    # LBFGS calls the closure itself, and step() returns the objective as
    # evaluated at the start of the step — so the extra forward/backward pass
    # the original performed before each step() was redundant work.
    obj = opt.step(eval_model)
    if i % 5 == 0:
        print(i, ':', obj.data[0])
m
Out[5]:
In [6]:
def plot(m):
    """Visualise a fitted multiclass GP model.

    Bottom panel: latent function means +/- 2 predictive std per class.
    Middle panel: predicted class probabilities over the input range.
    Top panel: data locations coloured by their class label.
    """
    f = pyplot.figure(figsize=(12, 6))
    a1 = f.add_axes([0.05, 0.05, 0.9, 0.6])   # latent functions
    a2 = f.add_axes([0.05, 0.7, 0.9, 0.1])    # class probabilities
    a3 = f.add_axes([0.05, 0.85, 0.9, 0.1])   # data rug plot
    xx = Variable(torch.linspace(m.X.data.min(), m.X.data.max(), 200).view(-1, 1).double())
    mu, var = m.predict_f(xx)
    mu, var = mu.clone(), var.clone()
    p, _ = m.predict_y(xx)
    # fix: the original called a3.set_xticks([]) twice; once each is enough
    a3.set_xticks([])
    a3.set_yticks([])
    for i in range(m.likelihood.num_classes):
        x = m.X.data[m.Y.data.view(-1) == i]
        points, = a3.plot(x.numpy(), x.numpy() * 0, '.')
        color = points.get_color()
        a1.plot(xx.data.numpy(), mu[:, i].data.numpy(), color=color, lw=2)
        a1.plot(xx.data.numpy(), mu[:, i].data.numpy() + 2 * var[:, i].data.numpy() ** 0.5, '--', color=color)
        a1.plot(xx.data.numpy(), mu[:, i].data.numpy() - 2 * var[:, i].data.numpy() ** 0.5, '--', color=color)
        a2.plot(xx.data.numpy(), p[:, i].data.numpy(), '-', color=color, lw=2)
    a2.set_ylim(-0.1, 1.1)
    a2.set_yticks([0, 1])
    a2.set_xticks([])
In [7]:
plot(m)
In [ ]:
In [8]:
# Second model: MCMC over the sparse GP (SGPMC) with the same data.
# Gamma priors are placed on the Matern kernel's variance and lengthscale;
# the white-noise variance and the inducing inputs Z are held fixed.
k_w = candlegp.kernels.White(1, variance=0.01).double()
k_m = candlegp.kernels.Matern32(1, lengthscales=0.1).double()
k = k_m + k_w
m = candlegp.models.SGPMC(Variable(X), Variable(Y), kern=k,
likelihood=candlegp.likelihoods.MultiClass(3),
Z=Variable(X[::5].clone()), num_latent=3)
k_m.variance.prior = candlegp.priors.Gamma(1.,1., ttype=torch.DoubleTensor)
k_m.lengthscales.prior = candlegp.priors.Gamma(2.,2., ttype=torch.DoubleTensor)
k_w.variance.requires_grad = False
m.Z.requires_grad = False
m
Out[8]:
In [9]:
# Optimize the SGPMC model to a good starting point before running HMC.
opt = torch.optim.LBFGS([p for p in m.parameters() if p.requires_grad], lr=1e-3, max_iter=40)

def eval_model():
    """LBFGS closure: zero grads, evaluate the objective, backprop, return it."""
    opt.zero_grad()
    obj = m()
    obj.backward()
    return obj

for i in range(100):
    # LBFGS calls the closure itself, and step() returns the objective as
    # evaluated at the start of the step — so the extra forward/backward pass
    # the original performed before each step() was redundant work.
    obj = opt.step(eval_model)
    if i % 5 == 0:
        print(i, ':', obj.data[0])
m
Out[9]:
In [10]:
plot(m)
In [11]:
import candlegp.training.hmc
# Draw 50 HMC samples of the model parameters (epsilon=0.04 step size, 50
# burn-in iterations, keep every 10th sample, up to 15 leapfrog steps).
# NOTE(review): from the usage below, res[0] appears to be a trace and
# res[1:] one list of sampled values per parameter — confirm against candlegp.
res = candlegp.training.hmc.hmc_sample(m, 50, epsilon=0.04, burn=50, thin=10, lmax=15)
In [ ]:
In [12]:
def plot_from_samples(m, samples):
    """Visualise the HMC posterior of a multiclass GP model.

    Bottom panel: one latent-function draw per posterior sample and class.
    Middle panel: predicted class probabilities for each posterior sample.
    Top panel: data locations coloured by class label.

    ``samples`` is the result of ``hmc_sample``; ``samples[1:]`` holds one
    list of sampled values per model parameter (see usage of ``res`` above).
    Realizing the (previously commented-out) function wrapper keeps the
    many intermediates — including names that shadowed the globals ``f`` and
    ``Y`` — out of the notebook namespace. The unused ``samples = res``
    assignment is gone; the sample list is now passed in explicitly.
    """
    fig = pyplot.figure(figsize=(12, 6))
    a1 = fig.add_axes([0.05, 0.05, 0.9, 0.6])   # latent function draws
    a2 = fig.add_axes([0.05, 0.7, 0.9, 0.1])    # class probabilities
    a3 = fig.add_axes([0.05, 0.85, 0.9, 0.1])   # data rug plot
    xx = torch.linspace(m.X.data.min(), m.X.data.max(), 200, out=torch.DoubleTensor()).unsqueeze(1)
    Fpred = []
    Ypred = []
    for ps in zip(*samples[1:]):
        # Load one posterior sample into the model parameters, then predict.
        for mp, p in zip(m.parameters(), ps):
            mp.set(p)
        Ypred.append(m.predict_y(Variable(xx))[0].data)
        Fpred.append(m.predict_f_samples(Variable(xx), 1).squeeze().t().data)
    Fpred = torch.stack(Fpred, dim=0)
    Ypred = torch.stack(Ypred, dim=0)
    for i in range(m.likelihood.num_classes):
        x = m.X.data[m.Y.data == i]
        points, = a3.plot(x.numpy(), torch.zeros_like(x).numpy(), '.')
        color = points.get_color()
        for F in Fpred:
            a1.plot(xx.numpy(), F[:, i].numpy(), color=color, lw=0.2, alpha=1.0)
        for Yp in Ypred:
            a2.plot(xx.numpy(), Yp[:, i].numpy(), color=color, lw=0.5, alpha=1.0)
    a2.set_ylim(-0.1, 1.1)
    a2.set_yticks([0, 1])
    a2.set_xticks([])
    a3.set_xticks([])
    a3.set_yticks([])

plot_from_samples(m, res)
Out[12]:
In [13]:
# Marginal posterior of the Matern lengthscale from the HMC samples.
lengthscaleidx = [i for i, (n, _) in enumerate(m.named_parameters()) if n.endswith(".lengthscales")][0]
# fix: `normed` was deprecated and later removed from pyplot.hist;
# `density=True` is the drop-in equivalent (normalise to a probability density).
pyplot.hist(numpy.array(res[1 + lengthscaleidx]), density=True, bins=50)
pyplot.xlabel('lengthscale')
Out[13]: