This example shows you how to use Population MCMC, also known as simulated tempering, to sample from a distribution with several well-separated modes.
It follows on from the first sampling example.
First, we create a multi-modal distribution:
In [2]:
from __future__ import print_function
import pints
import pints.toy as toy
import pints.plot
import numpy as np
import matplotlib.pyplot as plt
# Load a multi-modal logpdf: three 2-d Gaussian modes, placed far apart so
# that moving between them requires crossing regions of very low density.
modes = [
    [2, 2],
    [16, 12],
    [24, 24],
]
# Covariance matrices must be symmetric (and positive definite), so each
# matrix has equal off-diagonal entries.
covariances = [
    [[1.2, 0.0], [0.0, 1.2]],
    [[0.8, 0.1], [0.1, 1.4]],
    [[1.0, -0.5], [-0.5, 1.0]],
]
assert all(np.allclose(c, np.transpose(c)) for c in covariances), \
    'covariance matrices must be symmetric'
log_pdf = pints.toy.MultimodalGaussianLogPDF(modes, covariances)

# Contour plot of the pdf on a grid covering all three modes.
# Rows of Z follow y and columns follow x, matching np.meshgrid's layout.
x = np.linspace(0, 32, 80)
y = np.linspace(0, 32, 80)
X, Y = np.meshgrid(x, y)
Z = np.exp([[log_pdf([i, j]) for i in x] for j in y])

fig, ax = plt.subplots()
ax.contour(X, Y, Z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Multi-modal Gaussian density')
plt.show()
In [4]:
# Run standard adaptive MCMC (pints.HaarioBardenetACMC) with one chain
# started near each mode, then check convergence.
n_iterations = 8000  # iterations per chain
n_warmup = 2000      # initial iterations discarded before diagnostics

# Choose starting points for the mcmc chains (one near each mode)
xs = [[1, 1], [15, 13], [25, 23]]

# Create mcmc routine with one chain per starting point
mcmc = pints.MCMCController(
    log_pdf, len(xs), xs, method=pints.HaarioBardenetACMC)
mcmc.set_max_iterations(n_iterations)
# Disable logging mode
mcmc.set_log_to_screen(False)

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# Show traces and histograms of the full chains (including warm-up)
pints.plot.trace(chains)
plt.show()

# Discard warm up; the raw chains keep their own name
chains_post = chains[:, n_warmup:, :]

# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(chains_post))

# Effective sample size of all post-warm-up samples stacked together
samples = np.vstack(chains_post)
print('ESS:')
print(pints.effective_sample_size(samples))
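In this run, each chain explored only the mode it started in! If you re-run the cell, one of the chains may sometimes find 2 or 3 modes, but the result shown above occurs quite often. A standard MCMC chain rarely crosses the low-probability regions separating well-spaced modes. Because the chains then sample different regions, R-hat is well above 1.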
Now we try the same thing with population MCMC:
In [5]:
# Repeat the experiment with pints.PopulationMCMC, using the same log-pdf and
# starting points `xs` as the previous run.
n_iterations = 8000  # iterations per chain
n_warmup = 2000      # initial iterations discarded before analysis

# Create mcmc routine with one chain per starting point
mcmc = pints.MCMCController(
    log_pdf, len(xs), xs, method=pints.PopulationMCMC)
mcmc.set_max_iterations(n_iterations)
# Disable logging mode
mcmc.set_log_to_screen(False)

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# Show traces and histograms of the full chains (including warm-up)
pints.plot.trace(chains)

# Discard warm up; the raw chains keep their own name
chains_post = chains[:, n_warmup:, :]

# Look at distribution in chain 0
pints.plot.pairwise(chains_post[0], kde=True)

# Show graphs
plt.show()

# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(chains_post))

# Effective sample size of all post-warm-up samples stacked together
samples = np.vstack(chains_post)
print('ESS:')
print(pints.effective_sample_size(samples))