Inference: Population MCMC

This example shows you how to use Population MCMC, also known as simulated tempering.
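Population MCMC runs several chains, each targeting a "tempered" (flattened) version of the distribution. The sketch below shows why that helps. It assumes tempered targets of the form p(x)^(1 - T) with T in [0, 1]. We believe this matches the PINTS PopulationMCMC docstring, but please check it there. The 1-D bimodal density and the function names are hypothetical, chosen only to show how raising T lowers the barrier between two modes.

```python
import math

def log_p(x):
    """Hypothetical 1-D bimodal log-density: unnormalised sum of two
    unit-variance Gaussians centred at 0 and 10."""
    a = -0.5 * x ** 2
    b = -0.5 * (x - 10.0) ** 2
    m = max(a, b)  # log-sum-exp for numerical stability
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def log_tempered(x, temperature):
    """Tempered log-density (1 - T) * log p(x), assuming T in [0, 1]."""
    return (1.0 - temperature) * log_p(x)

def log_barrier(temperature):
    """Drop in tempered log-density from a mode (x = 0) to the valley (x = 5)."""
    return log_tempered(0.0, temperature) - log_tempered(5.0, temperature)

for t in (0.0, 0.5, 0.9):
    print(f"T = {t}: barrier = {log_barrier(t):.2f} nats")
# The barrier shrinks from 11.81 to 5.90 to 1.18 nats as T increases
```

A chain at a high temperature crosses the valley easily. An exchange move between chains can then pass the mode it found to the untempered chain.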

It follows on from the first sampling example.

First, we create a multi-modal distribution:


In [2]:
from __future__ import print_function
import pints
import pints.toy
import pints.plot
import numpy as np
import matplotlib.pyplot as plt

# Load a multi-modal log-pdf: three Gaussian modes, each with its own
# (symmetric, positive-definite) covariance matrix
log_pdf = pints.toy.MultimodalGaussianLogPDF(
    [
        [2, 2],
        [16, 12],
        [24, 24],
    ],
    [
        [[1.2, 0.0], [0.0, 1.2]],
        [[0.8, 0.2], [0.2, 1.4]],
        [[1.0, 0.5], [0.5, 1.0]],
    ]
)

# Contour plot of pdf
x = np.linspace(0, 32, 80)
y = np.linspace(0, 32, 80)
X, Y = np.meshgrid(x, y)
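# Evaluate the pdf on the grid: rows follow y and columns follow x,
# matching the layout produced by np.meshgrid
Z = np.exp([[log_pdf([i, j]) for i in x] for j in y])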
plt.contour(X, Y, Z)
plt.xlabel('x')
plt.ylabel('y')
plt.show()
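The toy class hides the details, but a mixture like this can be evaluated directly. The sketch below assumes that the log-pdf is the log of an unweighted (un-normalised) sum of the component Gaussian densities. We have not checked this against the PINTS source. It uses a log-sum-exp for numerical stability. The helpers gaussian_log_pdf_2d and mixture_log_pdf are hypothetical, not part of PINTS. The covariances are written as symmetric matrices, since a covariance matrix must be symmetric.

```python
import math

def gaussian_log_pdf_2d(x, mean, cov):
    """Log-density of a 2-D Gaussian with a symmetric 2x2 covariance."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    if det <= 0:
        raise ValueError("covariance must be positive definite")
    dx, dy = x[0] - mean[0], x[1] - mean[1]
    # (x - mu)^T Sigma^-1 (x - mu), using the closed-form 2x2 inverse
    quad = (d * dx * dx - (b + c) * dx * dy + a * dy * dy) / det
    return -0.5 * quad - math.log(2.0 * math.pi) - 0.5 * math.log(det)

def mixture_log_pdf(x, means, covs):
    """Log of the unweighted sum of Gaussian densities (hypothetical helper)."""
    terms = [gaussian_log_pdf_2d(x, m, s) for m, s in zip(means, covs)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))

means = [[2, 2], [16, 12], [24, 24]]
covs = [
    [[1.2, 0.0], [0.0, 1.2]],
    [[0.8, 0.2], [0.2, 1.4]],
    [[1.0, 0.5], [0.5, 1.0]],
]

# High at a mode, many orders of magnitude lower between modes
print(mixture_log_pdf([2, 2], means, covs))
print(mixture_log_pdf([9, 7], means, covs))
```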


Exploration with adaptive MCMC

Now let's explore this landscape with adaptive covariance MCMC (pints.HaarioBardenetACMC). We use three chains, each started near a different mode.
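Adaptive covariance MCMC tunes its random-walk proposal using the samples it has already drawn. The sketch below is a simplified 1-D illustration of that idea, not the HaarioBardenetACMC implementation in PINTS. It keeps a Robbins-Monro style running estimate of the mean and variance with a decaying gain. The target, the gain exponent of 0.6 and the 2.38 step scaling (the usual 2.38/sqrt(d) rule for d = 1) are assumptions chosen for illustration. The chain only ever sees samples from the mode it starts in, so it learns a proposal sized for that mode and never jumps the valley.

```python
import math
import random

def log_p(x):
    """Hypothetical 1-D bimodal target: unit-variance modes at 0 and 30."""
    a = -0.5 * x ** 2
    b = -0.5 * (x - 30.0) ** 2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def adaptive_metropolis(x0, n_iter, seed=1):
    """Simplified 1-D adaptive random-walk Metropolis (illustration only)."""
    rng = random.Random(seed)
    x, lp = x0, log_p(x0)
    mu, var = x0, 1.0  # running estimates of the target mean and variance
    samples = []
    for t in range(1, n_iter + 1):
        # Propose with a step size scaled from the current variance estimate
        y = x + rng.gauss(0.0, 2.38 * math.sqrt(var + 1e-6))
        lp_y = log_p(y)
        # Metropolis accept/reject; 1 - random() lies in (0, 1], so log is safe
        if math.log(1.0 - rng.random()) < lp_y - lp:
            x, lp = y, lp_y
        samples.append(x)
        # Robbins-Monro update with a decaying gain, so adaptation fades out
        gain = (t + 1) ** -0.6
        diff = x - mu
        mu += gain * diff
        var += gain * (diff * diff - var)
    return samples, var


samples, var = adaptive_metropolis(0.0, 5000)
print(max(samples), var)
```

Started at 0, the chain stays in the left mode and its variance estimate settles near that mode's unit variance. Nothing in the adaptation pushes it towards the mode at 30.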


In [4]:
# Choose starting points for 3 mcmc chains
xs = [[1, 1], [15, 13], [25, 23]]

# Create mcmc routine
mcmc = pints.MCMCController(log_pdf, 3, xs, method=pints.HaarioBardenetACMC)

# Add stopping criterion
mcmc.set_max_iterations(8000)

# Disable logging to screen
mcmc.set_log_to_screen(False)

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# Show traces and histograms
pints.plot.trace(chains)
plt.show()

# Discard warm up
chains = chains[:, 2000:, :]

# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(chains))

# Check effective sample size
chains = np.vstack(chains)
print('ESS:')
print(pints.effective_sample_size(chains))


Running...
Done!
R-hat:
[11.24903076955359, 9.774488531939205]
ESS:
[3.0759802009342843, 3.0466738317607924]

In this run, each chain explored only its own mode! The R-hat values far above 1 and an ESS of about 3 confirm that the three chains never mixed. If you re-run the example, one of the chains may find two or three modes, but the outcome shown above is common.
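R-hat compares the spread between chains with the spread within them. When chains sit in different modes, the between-chain spread dominates and R-hat grows well above 1. The sketch below implements the basic Gelman-Rubin formula for a single parameter. PINTS's rhat_all_params may differ in details, so treat this as an illustration only. The rhat function and the synthetic chains are hypothetical. The "stuck" chains are centred at the x-coordinates of the three modes.

```python
import math
import random
import statistics

def rhat(chains):
    """Basic Gelman-Rubin R-hat for one parameter; chains is a list of lists."""
    n = len(chains[0])
    means = [statistics.mean(c) for c in chains]
    within = statistics.mean([statistics.variance(c) for c in chains])
    between = n * statistics.variance(means)
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)

rng = random.Random(0)
# Three chains that each stay in a different mode
stuck = [[rng.gauss(mu, 1.0) for _ in range(1000)] for mu in (2, 16, 24)]
# Three chains that all sample the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(3)]
print(rhat(stuck), rhat(mixed))
```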

Now we try the same thing with population MCMC, starting from the same three points.
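What lets population MCMC move between modes is an exchange move between chains at different temperatures. A hot chain that has found a new mode can hand it to the cold chain. The sketch below shows only the swap acceptance test, not the PINTS implementation. It assumes chain k targets p(x)^(1 - T_k). Under that assumption, the log acceptance ratio log[p_i(x_j) p_j(x_i) / (p_i(x_i) p_j(x_j))] reduces to (log p(x_j) - log p(x_i)) (T_j - T_i). The names exchange_log_ratio and maybe_swap, and the numbers in the example, are hypothetical.

```python
import math
import random

def exchange_log_ratio(log_p_i, log_p_j, temp_i, temp_j):
    """Log acceptance ratio for swapping the states of chains i and j.

    Assumes chain k targets p(x) ** (1 - T_k); log_p_i and log_p_j are the
    untempered log-densities log p(x_i) and log p(x_j) of the current states.
    """
    return (log_p_j - log_p_i) * (temp_j - temp_i)

def maybe_swap(states, log_ps, temps, i, j, rng):
    """Swap the states of chains i and j in place with probability min(1, A)."""
    log_a = exchange_log_ratio(log_ps[i], log_ps[j], temps[i], temps[j])
    if math.log(1.0 - rng.random()) < log_a:
        states[i], states[j] = states[j], states[i]
        log_ps[i], log_ps[j] = log_ps[j], log_ps[i]
        return True
    return False

# Hypothetical values: the cold chain (T = 0) sits at a low-density point,
# while the hot chain (T = 0.8) has found a high-density mode
states, log_ps, temps = [5.0, 10.0], [-11.8, 0.0], [0.0, 0.8]
print(maybe_swap(states, log_ps, temps, 0, 1, random.Random(0)), states)
# prints: True [10.0, 5.0]
```

Here the log ratio is 11.8 * 0.8 = 9.44 > 0, so the swap is always accepted and the cold chain jumps straight to the better mode.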


In [5]:
# Create mcmc routine
mcmc = pints.MCMCController(log_pdf, 3, xs, method=pints.PopulationMCMC)

# Add stopping criterion
mcmc.set_max_iterations(8000)

# Disable logging to screen
mcmc.set_log_to_screen(False)

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# Show traces and histograms
pints.plot.trace(chains)

# Discard warm up
chains = chains[:, 2000:, :]

# Look at distribution in chain 0
pints.plot.pairwise(chains[0], kde=True)

# Show graphs
plt.show()

# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(chains))

# Check effective sample size
chains = np.vstack(chains)
print('ESS:')
print(pints.effective_sample_size(chains))


Running...
Done!
R-hat:
[1.0298846623828417, 1.0300139160727695]
ESS:
[73.13271857670077, 84.67615275389988]
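Between the two runs the ESS rose from about 3 to roughly 75 to 85. ESS estimates how many independent draws a correlated chain is worth, roughly n / (1 + 2 * sum of autocorrelations). The sketch below truncates that sum at the first negative autocorrelation, which is one common heuristic. PINTS's effective_sample_size may use a different estimator, so the numbers need not match. The ess function and the AR(1) test chain are hypothetical.

```python
import random
import statistics

def autocorrelation(x, lag):
    """Sample autocorrelation of sequence x at the given lag."""
    n = len(x)
    mean = statistics.mean(x)
    var = sum((v - mean) ** 2 for v in x) / n
    cov = sum((x[t] - mean) * (x[t + lag] - mean) for t in range(n - lag)) / n
    return cov / var

def ess(x, max_lag=200):
    """ESS = n / (1 + 2 * sum of autocorrelations), truncating the sum at
    the first negative autocorrelation (one common heuristic)."""
    total = 0.0
    for lag in range(1, max_lag + 1):
        rho = autocorrelation(x, lag)
        if rho < 0:
            break
        total += rho
    return len(x) / (1.0 + 2.0 * total)

rng = random.Random(0)
independent = [rng.gauss(0.0, 1.0) for _ in range(2000)]
# A strongly correlated AR(1) chain mimics a slowly mixing sampler
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.95 * correlated[-1] + rng.gauss(0.0, 1.0))
print(ess(independent), ess(correlated))
```

The independent draws keep an ESS close to their length, while the correlated chain of the same length is worth only a few dozen independent samples.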