Writing a custom LogPrior

This example shows you how to write a custom LogPrior class to use in your inference problems.

Most priors in Pints are defined over just 1 parameter, and then combined into a multivariate prior using the ComposedLogPrior class. But sometimes you might want to define a prior over multiple variables at once. In such a case, it may make sense to create a custom LogPrior class.
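Composing 1-d priors this way amounts to adding their log-densities. As a result, the region where the joint prior is non-zero is always a box, that is, a product of intervals. The numpy sketch below illustrates the idea. The helper names are hypothetical, and the code mimics the behaviour described here rather than reproducing ComposedLogPrior's actual source:

```python
import numpy as np

def uniform_log_pdf(x, lower, upper):
    """Log-density of a 1-d uniform distribution on [lower, upper)."""
    # Hypothetical helper, not part of PINTS
    if lower <= x < upper:
        return -np.log(upper - lower)
    return -np.inf

def composed_log_pdf(params, bounds):
    """Sum independent 1-d log-densities, one per parameter."""
    return sum(uniform_log_pdf(x, lo, hi) for x, (lo, hi) in zip(params, bounds))

bounds = [(-1.0, 1.0), (-1.0, 1.0)]
# Each parameter is checked on its own, so the corner point (0.9, 0.9) is
# accepted even though it lies outside the unit disc
print(composed_log_pdf([0.9, 0.9], bounds))   # log(1/4), about -1.386
print(composed_log_pdf([1.5, 0.0], bounds))   # -inf
```

A disc-shaped support cannot be written as such a sum of per-parameter terms, so it needs a prior that looks at both parameters at once.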

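In this example, we'll implement a uniform 2D log-prior whose support is a circle in the parameter space. Points inside the circle have probability density c = 1/A, where A = πr² is the area of a circle with radius r. Points outside it have density 0. Because this is a log-prior, it returns log(c) inside the circle and -inf outside.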
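The density is constant inside the circle, so it integrates to one exactly when c = 1/A. A quick Monte Carlo check of this follows. It is a sketch, and the radius, seed and sample size are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(1)
radius = 2.0
area = np.pi * radius ** 2     # A
c = 1.0 / area                 # density inside the circle

# Sample uniformly from the bounding square [-r, r] x [-r, r]
n = 200_000
points = rng.uniform(-radius, radius, size=(n, 2))
inside = np.hypot(points[:, 0], points[:, 1]) < radius

# Integral of the density over the square: c * (square area) * (fraction inside)
integral = c * (2 * radius) ** 2 * inside.mean()
print(integral)   # close to 1
print(np.log(c))  # the value the log-prior returns inside the circle
```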

Before we begin, let's define a toy problem, like in the basic MCMC example:


In [1]:
from __future__ import print_function
import pints
import pints.toy
import pints.plot
import numpy as np
import matplotlib.pyplot as plt

# Load a forward model
model = pints.toy.LogisticModel()

# Create some toy data
real_parameters = np.array([1, 500])
times = np.linspace(0, 1000, 1000)
org_values = model.simulate(real_parameters, times)

# Add noise
noise = 0.10
values = org_values + np.random.normal(0, noise, org_values.shape)

# Create an object with links to the model and time series
problem = pints.SingleOutputProblem(model, times, values)

# Create a log-likelihood function (the noise level is known, so no extra parameter is added)
log_likelihood = pints.GaussianKnownSigmaLogLikelihood(problem, noise)
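
A Gaussian log-likelihood with a known noise level σ scores a parameter vector using only the model's simulated values. That is why it needs no extra parameter. The numpy sketch below shows the standard i.i.d. Gaussian formula. The function name is hypothetical, and PINTS' internal implementation may be organised differently:

```python
import numpy as np

def gaussian_known_sigma_log_likelihood(values, simulated, sigma):
    """Standard i.i.d. Gaussian log-likelihood with a fixed noise level sigma."""
    values = np.asarray(values, dtype=float)
    simulated = np.asarray(simulated, dtype=float)
    n = values.size
    residuals = values - simulated
    return (-0.5 * n * np.log(2 * np.pi)
            - n * np.log(sigma)
            - np.sum(residuals ** 2) / (2 * sigma ** 2))

# A perfect fit leaves only the normalising constant, -log(2*pi) here
print(gaussian_known_sigma_log_likelihood([1.0, 2.0], [1.0, 2.0], sigma=1.0))
```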

Now let's add our prior.

First, we take a look at the LogPrior class, as well as the LogPDF class which it extends.

From this we learn that:

  1. We need to provide a method n_parameters() that returns the dimension of the parameter space our prior is defined on.
  2. Objects of our prior class should be callable: they should have a __call__ method that takes a parameter vector as input and returns a scalar, the log of the prior density at that point.
  3. If we like, we can implement an optional sample() method.

Leaving aside the sample() method, we can implement our prior like this:


In [2]:
class CircularLogPrior(pints.LogPrior):
    def __init__(self, center, radius):
        # Store the circle's center
        self._center = np.array(center, dtype=float)
        # Store the circle's radius
        self._r = float(radius)
        # The density inside the circle is c = 1 / A, with A the circle's area;
        # store its logarithm, log(c)
        A = np.pi * self._r * self._r
        self._c = np.log(1 / A)

    def n_parameters(self):
        return 2

    def __call__(self, parameters):
        # Get coordinates relative to the center
        p = np.array(parameters) - self._center

        # Calculate the distance from the center
        r = np.sqrt(p[0]*p[0] + p[1]*p[1])

        # Return log(c) strictly inside the circle, -inf elsewhere
        return self._c if r < self._r else -float('inf')
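
The class above leaves out the optional sample() method. One way to implement it is inverse-transform sampling in polar coordinates. Draw an angle uniformly in [0, 2π) and a radius as r·sqrt(u), with u uniform on [0, 1). The square root keeps the density uniform per unit area, rather than bunching points near the centre. The sketch below is a standalone function. Its name, sample_disc, is hypothetical, and returning an (n, 2) array is an assumption about what a sample(n) method should produce:

```python
import numpy as np

def sample_disc(center, radius, n=1, rng=None):
    """Draw n points uniformly from a disc; returns an array of shape (n, 2)."""
    rng = np.random.default_rng() if rng is None else rng
    center = np.asarray(center, dtype=float)
    # Angle uniform on [0, 2*pi)
    theta = rng.uniform(0, 2 * np.pi, size=n)
    # sqrt makes the radial density proportional to r, i.e. uniform per unit area
    r = radius * np.sqrt(rng.uniform(0, 1, size=n))
    offsets = np.column_stack((r * np.cos(theta), r * np.sin(theta)))
    return center + offsets

samples = sample_disc([0, 0], 1, n=5, rng=np.random.default_rng(0))
print(samples.shape)  # (5, 2)
```

Inside CircularLogPrior, the same logic could become a method such as `def sample(self, n=1): return sample_disc(self._center, self._r, n)`.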

Now we can test our prior with a few values:


In [3]:
p = CircularLogPrior([0, 0], 1)
print(p([0, 0]))         # log(c)
print(p([1, 1]))         # -inf (outside)
print(p([-0.5, -0.5]))   # log(c)
print(p([-1.0, 0]))      # -inf (on the boundary; the test r < radius is strict)
print(p([0.9, 0]))       # log(c)


-1.14472988585
-inf
-1.14472988585
-inf
-1.14472988585
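
The printed value -1.14472988585 is log(1/π), the log-density for the unit circle. The -inf values are what keep a sampler inside the circle. pints.LogPosterior adds the log-prior to the log-likelihood, so any parameter vector outside the circle has a log-posterior of -inf. For a Metropolis-type sampler, that gives an acceptance probability of exp(-inf) = 0. The toy random-walk Metropolis sampler below illustrates this. It is a plain numpy sketch with a made-up stand-in likelihood (standard normal), not PINTS' adaptive HaarioBardenetACMC method:

```python
import numpy as np

def log_prior(x, center=(0.0, 0.0), radius=1.0):
    # Uniform on a disc, as in CircularLogPrior: log(1 / A) inside, -inf outside
    offset = np.asarray(x, dtype=float) - np.asarray(center, dtype=float)
    if np.hypot(offset[0], offset[1]) < radius:
        return -np.log(np.pi * radius ** 2)
    return -np.inf

def log_likelihood(x):
    # Hypothetical stand-in: independent standard normals (up to a constant)
    return -0.5 * np.sum(np.asarray(x) ** 2)

def log_posterior(x):
    lp = log_prior(x)
    # If the prior rules x out, the likelihood does not matter
    if np.isneginf(lp):
        return lp
    return lp + log_likelihood(x)

def metropolis(x0, n_steps, step_size, rng):
    """Random-walk Metropolis with Gaussian proposals."""
    x = np.asarray(x0, dtype=float)
    fx = log_posterior(x)
    chain = np.empty((n_steps, x.size))
    for i in range(n_steps):
        proposal = x + step_size * rng.normal(size=x.size)
        fp = log_posterior(proposal)
        # Accept with probability min(1, exp(fp - fx)); exp(-inf) = 0 means
        # proposals outside the circle are always rejected
        if np.log(rng.uniform()) < fp - fx:
            x, fx = proposal, fp
        chain[i] = x
    return chain

chain = metropolis([0.5, 0.5], 5000, 0.5, np.random.default_rng(3))
print(np.hypot(chain[:, 0], chain[:, 1]).max())  # always below the radius, 1
```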

Now we can try using it in an inference routine:


In [4]:
# Create a log-prior
log_prior = CircularLogPrior(center=real_parameters, radius=0.001)

# Create an (unnormalised) log-posterior: log(likelihood * prior)
log_posterior = pints.LogPosterior(log_likelihood, log_prior)

# Choose starting points for 3 MCMC chains, all inside the prior's circle
xs = [
    real_parameters + np.array([0.0005, -0.0005]),
    real_parameters + np.array([-0.0005, -0.0005]),
    real_parameters + np.array([-0.0005, 0.0005]),
]

# Create MCMC routine
mcmc = pints.MCMCController(log_posterior, 3, xs, method=pints.HaarioBardenetACMC)

# Add stopping criterion
mcmc.set_max_iterations(4000)

# Start adapting after 1000 iterations
mcmc.set_initial_phase_iterations(1000)

# Disable logging
mcmc.set_log_to_screen(False)

# Run!
print('Running...')
chains = mcmc.run()
print('Done!')

# Show traces and histograms
pints.plot.trace(chains)

# Discard warm-up (first 2000 iterations)
chains = chains[:, 2000:, :]

# Check convergence using rhat criterion
print('R-hat:')
print(pints.rhat_all_params(chains))

# Look at distribution in chain 0
pints.plot.pairwise(chains[0], kde=True)

# Show graphs
plt.show()


Running...
Done!
R-hat:
[1.0000189021071446, 1.0019107390010724]
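
R-hat values close to 1 indicate that the three chains agree with each other. The classic Gelman-Rubin statistic compares the variance between chain means with the variance within chains. The numpy sketch below implements that textbook formula for an array shaped like `chains` (chains × iterations × parameters). PINTS' own function may use a variant, such as split chains, so its numbers need not match this sketch exactly:

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat per parameter for an array of shape (m, n, d)."""
    chains = np.asarray(chains, dtype=float)
    m, n, _ = chains.shape
    chain_means = chains.mean(axis=1)                 # (m, d)
    within = chains.var(axis=1, ddof=1).mean(axis=0)  # W: mean within-chain variance
    between = n * chain_means.var(axis=0, ddof=1)     # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
well_mixed = rng.normal(size=(3, 2000, 2))
print(gelman_rubin_rhat(well_mixed))  # both values close to 1
```

Shifting each chain to a different mean inflates B relative to W and pushes R-hat well above 1. That is the signal that chains have not converged to the same distribution.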