This example shows you how to perform Bayesian inference on Gaussian distributions, using Slice Sampling with Stepout and Overrelaxation.
Slice Sampling with Stepout is a univariate method; applied in a Slice-Sampling-within-Gibbs framework, it allows MCMC sampling from multivariate models. It generates samples by sampling uniformly from the region underneath the posterior density ($f$). It does so by introducing an auxiliary variable ($y$) and by defining a suitable Markov chain.
If the distribution is univariate, a single update from the current state $x_0$ proceeds as follows:
1. Draw the auxiliary variable $y$ uniformly from $(0, f(x_0))$; this defines the horizontal "slice" $S = \{x : y < f(x)\}$, which contains $x_0$.
2. Find an interval $I = (L, R)$ around $x_0$ that contains all, or much, of the slice.
3. Draw the new state $x_1$ uniformly from the part of the slice within $I$.
If the distribution is multivariate, we apply the univariate algorithm to each variable in turn, holding the other variables at their current values.
In this notebook, we use the Stepout procedure to estimate the interval $I$: starting from an interval of width $w$ placed randomly around the current state, we expand it by $w$ in each direction until both edges fall outside the slice, or until a predetermined limit on the number of expansions is reached. A minimal sketch of this univariate update is given below.
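To make this concrete, here is a minimal, self-contained sketch of one univariate Stepout update, following Neal [1]. The function and parameter names (slice_sample_stepout, log_f, w, m) are illustrative, not part of Pints:
In [ ]:
import numpy as np

def slice_sample_stepout(log_f, x0, w=1.0, m=100, rng=None):
    # One univariate slice-sampling update using the stepout procedure.
    # log_f: log of the (unnormalised) target density f
    # x0: current state; w: estimated slice width; m: expansion limit
    rng = np.random.default_rng() if rng is None else rng

    # Sample the auxiliary variable: log y = log f(x0) - Exp(1),
    # which defines the slice S = {x : y < f(x)}
    log_y = log_f(x0) - rng.exponential(1.0)

    # Stepout: place an interval of width w at random around x0, then
    # expand each edge by w until it leaves the slice or the budget m
    # is spent
    L = x0 - w * rng.uniform()
    R = L + w
    j = int(m * rng.uniform())
    k = (m - 1) - j
    while j > 0 and log_f(L) > log_y:
        L -= w
        j -= 1
    while k > 0 and log_f(R) > log_y:
        R += w
        k -= 1

    # Sample uniformly from the interval, shrinking it towards x0
    # whenever a candidate falls outside the slice
    while True:
        x1 = rng.uniform(L, R)
        if log_f(x1) > log_y:
            return x1
        if x1 < x0:
            L = x1
        else:
            R = x1

# Example usage: one update on a standard normal
# log_f = lambda x: -0.5 * x ** 2
# x1 = slice_sample_stepout(log_f, 0.0)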
Overrelaxed steps increase sampling efficiency in highly correlated unimodal distributions by suppressing the random-walk behaviour of single-variable slice sampling. Each variable is still updated in turn, but rather than drawing a new value for a variable from its conditional distribution independently of the current value, the new value is chosen to lie on the opposite side of the mode from the current value. The interval $I$ is still found via Stepout, and its edges $(L, R)$ are refined towards the slice endpoints via bisection. To obtain a full sampling scheme, overrelaxed updates are alternated with normal Stepout updates. A simplified sketch of one overrelaxed update follows.
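This sketch uses the same illustrative names as above and simplifies Neal's procedure (Pints' actual implementation handles the endpoints and the acceptance test more carefully):
In [ ]:
import numpy as np

def slice_sample_overrelax(log_f, x0, w=1.0, m=100, a=10, rng=None):
    # One overrelaxed slice-sampling update (simplified from Neal [1]):
    # the stepout edges are refined by a bisection steps, and the
    # proposal is the reflection of x0 through the interval midpoint.
    rng = np.random.default_rng() if rng is None else rng
    log_y = log_f(x0) - rng.exponential(1.0)

    # Stepout, exactly as in the previous sketch
    L = x0 - w * rng.uniform()
    R = L + w
    j = int(m * rng.uniform())
    k = (m - 1) - j
    while j > 0 and log_f(L) > log_y:
        L -= w
        j -= 1
    while k > 0 and log_f(R) > log_y:
        R += w
        k -= 1

    def refine(edge):
        # Bisect between x0 (inside the slice) and the edge (outside)
        # to approximate the slice endpoint on that side
        if log_f(edge) > log_y:
            return edge  # expansion budget ran out inside the slice
        inner, outer = x0, edge
        for _ in range(a):
            mid = (inner + outer) / 2
            if log_f(mid) > log_y:
                inner = mid
            else:
                outer = mid
        return (inner + outer) / 2

    L, R = refine(L), refine(R)

    # Reflect x0 through the midpoint of the refined interval; if the
    # reflected point falls outside the slice, stay at x0
    x1 = L + R - x0
    return x1 if log_f(x1) > log_y else x0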
First, we create a normal distribution with correlated parameters.
In [44]:
import os
os.chdir("../")
import pints
import pints.toy
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
# Create log pdf
log_pdf = pints.toy.GaussianLogPDF([2, 4], [[1, 0.96], [0.96, 1]])
# Contour plot of pdf
num_points = 100
x = np.linspace(-1, 5, num_points)
y = np.linspace(0, 8, num_points)
X, Y = np.meshgrid(x, y)
Z = np.exp([[log_pdf([i, j]) for i in x] for j in y])
plt.contour(X, Y, Z)
plt.xlabel('x')
plt.ylabel('y')
plt.show()
The probability of performing an overrelaxed step (default 0) and the number of bisection iterations are hyperparameters of the method. In [1], Neal suggests making almost every update overrelaxed and setting a high limit on the number of interval-expansion steps.
[1] Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), pp. 705-767.
In [45]:
# Choose starting points for 3 mcmc chains
xs = [
[2, 4],
[3, 3],
[5, 4],
]
# Create mcmc routine
mcmc = pints.MCMCController(log_pdf, 3, xs, method=pints.SliceStepoutMCMC)
# Add stopping criterion
mcmc.set_max_iterations(2000)
# Set up modest logging
mcmc.set_log_to_screen(True)
mcmc.set_log_interval(500)
for sampler in mcmc.samplers():
    sampler.set_width([0.1, 0.1])
    sampler.set_expansion_steps(100)
    sampler.set_prob_overrelaxed(0.90)
    sampler.set_bisection_steps(10)
# Run!
print('Running...')
full_chains_overrelaxed = mcmc.run()
print('Done!')
We first use Pints' trace tool to check the convergence of the chains.
In [46]:
# Show traces and histograms
import pints.plot
pints.plot.trace(full_chains_overrelaxed)
plt.show()
From the plots, we can see that the chains rapidly converge after a few MCMC steps.
We can further check this by using the R-hat criterion, expecting this to be approximately $1$.
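Roughly speaking, for each parameter $\hat{R}$ compares the within-chain variance $W$ with the between-chain variance $B$: $\hat{R} = \sqrt{\hat{V}/W}$ with $\hat{V} = \frac{n-1}{n}W + \frac{1}{n}B$, where $n$ is the chain length, so chains exploring the same distribution give $\hat{R} \approx 1$ (exact definitions vary slightly between implementations).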
In [47]:
# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(full_chains_overrelaxed))
As our target distribution is Gaussian, we can calculate analytically the KL divergence between the distributions obtained from the chains and the true posterior distribution. We would expect the KL values to be approximately $0$, indicating no loss of information between the distributions.
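For reference, the divergence between two Gaussians has the closed form
$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu_0, \Sigma_0)\,\|\,\mathcal{N}(\mu_1, \Sigma_1)\big) = \frac{1}{2}\left(\operatorname{tr}\left(\Sigma_1^{-1}\Sigma_0\right) + (\mu_1 - \mu_0)^\top \Sigma_1^{-1}(\mu_1 - \mu_0) - d + \ln\frac{\det\Sigma_1}{\det\Sigma_0}\right),$$
where $d$ is the dimension; kl_divergence() effectively evaluates this with the sample mean and covariance of each chain in place of $\mu_0$ and $\Sigma_0$.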
In [48]:
# Check Kullback-Leibler divergence of chains
print(log_pdf.kl_divergence(full_chains_overrelaxed[0]))
print(log_pdf.kl_divergence(full_chains_overrelaxed[1]))
print(log_pdf.kl_divergence(full_chains_overrelaxed[2]))
We can now look at the correlation between the different parameters by using the pairwise() plot.
In [49]:
pints.plot.pairwise(full_chains_overrelaxed[2], kde=True)
plt.show()
We now create a new function to plot the average KL divergence among the different chains against the number of MCMC steps.
In [61]:
import matplotlib.pyplot as plt

def plot_kl(chains, name):
    # Average, over the chains, of the KL divergence computed from
    # the first i samples of each chain
    steps = range(2, 100)  # at least 2 samples needed to estimate a covariance
    kl = []
    for i in steps:
        divergences = [log_pdf.kl_divergence(chain[:i]) for chain in chains]
        kl.append(np.mean(divergences))
    plt.plot(steps, kl)
    plt.title(name)
    plt.ylim(0, 10)
    plt.ylabel('KL Divergence')
    plt.xlabel('Iteration')
    plt.show()
In [62]:
plot_kl(full_chains_overrelaxed, "90% Overrelaxation")
From the plot, we can see that the KL divergence rapidly approaches $0$ after approximately 60 MCMC steps, indicating that the distribution obtained using the samples rapidly converges to the target distribution.
We compare the KL plot for the overrelaxed run against a normal Slice Sampling with Stepout run.
In [60]:
# Create mcmc routine
mcmc = pints.MCMCController(log_pdf, 3, xs, method=pints.SliceStepoutMCMC)
# Add stopping criterion
mcmc.set_max_iterations(2000)
# Set up modest logging
mcmc.set_log_to_screen(True)
mcmc.set_log_interval(1000)
for sampler in mcmc.samplers():
    sampler.set_width([0.1, 0.1])
    sampler.set_expansion_steps(100)
    sampler.set_prob_overrelaxed(0)
# Run!
print('Running...')
full_chains_stepout = mcmc.run()
print('Done!')
In [64]:
plot_kl(full_chains_stepout, "Stepout")
Although slightly better, the plot for the overrelaxed run is very similar to that of the normal Slice Sampling with Stepout run, showing only a marginal improvement in the speed at which the KL divergence approaches $0$.
We now test Slice Sampling with Overrelaxation on a 10-dimensional Gaussian distribution.
In [71]:
# Create log pdf
log_pdf = pints.toy.HighDimensionalGaussianLogPDF(dimension=10)
# Choose random starting points for 3 mcmc chains
x0 = np.random.uniform(-10, -5, size=(3, 10))
# Create a slice sampling mcmc routine
mcmc = pints.MCMCController(log_pdf, 3, x0, method=pints.SliceStepoutMCMC)
for sampler in mcmc.samplers():
    sampler.set_width(1)
    sampler.set_expansion_steps(50)
    sampler.set_bisection_steps(10)
    sampler.set_prob_overrelaxed(0.90)
# Stop after 100 iterations
mcmc.set_max_iterations(100)
mcmc.set_log_interval(50)
# Run!
print('Running...')
full_chains_HD_overrelaxation = mcmc.run()
print('Done!')
# Create a new controller for the comparison run without overrelaxation
# (an MCMCController can only be run once; bisection settings are
# irrelevant when the overrelaxation probability is 0)
mcmc = pints.MCMCController(log_pdf, 3, x0, method=pints.SliceStepoutMCMC)
for sampler in mcmc.samplers():
    sampler.set_width(1)
    sampler.set_expansion_steps(50)
    sampler.set_prob_overrelaxed(0)
# Stop after 100 iterations
mcmc.set_max_iterations(100)
mcmc.set_log_interval(50)
# Run!
print('Running...')
full_chains_HD_stepout = mcmc.run()
print('Done!')
In [72]:
plot_kl(full_chains_HD_overrelaxation, "Overrelaxation")
plot_kl(full_chains_HD_stepout, "Stepout")
Interestingly, in higher dimensions the Overrelaxation run approaches $0$ more slowly than a normal Slice Sampling with Stepout run. This indicates that the probability of performing an overrelaxed step is an important hyperparameter which requires appropriate tuning; a possible tuning loop is sketched below.
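As an illustration of such tuning, reusing the controller set-up and the plot_kl helper from above (the candidate probabilities below are arbitrary), one could compare KL trajectories across several overrelaxation probabilities:
In [ ]:
# Sketch: compare KL trajectories for several overrelaxation
# probabilities on the 10-dimensional Gaussian (values are arbitrary)
for prob in [0.0, 0.5, 0.9, 0.99]:
    mcmc = pints.MCMCController(log_pdf, 3, x0, method=pints.SliceStepoutMCMC)
    mcmc.set_max_iterations(100)
    mcmc.set_log_to_screen(False)
    for sampler in mcmc.samplers():
        sampler.set_width(1)
        sampler.set_expansion_steps(50)
        sampler.set_bisection_steps(10)
        sampler.set_prob_overrelaxed(prob)
    plot_kl(mcmc.run(), 'prob_overrelaxed = %g' % prob)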