Fitting a Model the Bayesian Way with PyStan

This notebook is a continuation of the previous one (MCMC with emcee). The first part is identical: make some fake data (an emission line) and fit it with a non-linear model (Gaussian + background). But this time, we use a different MCMC package, called pystan, which uses the STAN sampler. It has many advantages over emcee, particularly being able to handle large numbers of parameters. The disadvantage is that you have to write the model in STAN's own language, and it can be a pain to debug in two languages.

This notebook requires the pystan module and the corner module. You can install them (if needed) by doing:

conda install pystan
pip install corner

in the terminal.

1 Making a Fake Emission Line

The "true" data is some background flux of photons (a continuum from the source or background) plus a Gaussian line with some amplitude, width and center. I set these up as variables so it's easy to play around with them and see how things change.


In [ ]:
from numpy import *  # mmmmmm  crunchy
# Start by defining some parameters. Change these if you like!
cont_zp = 500.0
cont_slope = 5.0
amplitude = 150.0
width = 0.5
center = 5.0

# Next, a grid of wavelength channels (assumed to have no uncertainty)
wave = linspace(0,10,100)
# The 'true' observations
flux = amplitude*exp(-0.5*power(wave-center,2)/width**2) + \
       cont_zp + cont_slope*wave
# The actual observations = true observations + Poisson noise
obs_flux = random.poisson(flux)

So we have the wavelength on the x-axis, which is assumed to have no uncertainty. The measured flux is different from the "true" flux due to Poisson noise. Let's plot the true flux and observed flux to see how things look.


In [ ]:
%matplotlib inline
from matplotlib.pyplot import subplots
fig,ax = subplots(1,1)
ax.plot(wave, flux, 'r-')
ax.step(wave, obs_flux, color='k')
ax.set_xlabel('Wavelength (Angstroms)')
ax.set_ylabel('Counts')

2 Writing the STAN Model

With emcee, we wrote functions to compute priors and likelihoods. STAN is different. It has its own language, very similar to C. The nice thing about it is that it has many probability distributions built in. You also write it very similarly to the notation we use to describe Bayesian models. One important symbol is $\sim$, which you should read as "drawn from". If, for example, we had a measurement $x$ with an error $\sigma_x$, and we say the true value is $x^T$, then if the errors are normally distributed we can write: $$ x \sim N\left(x^T, \sigma_x\right) $$ that is, $x$ is a random value drawn from a normal distribution ($N$) with center $x^T$ and standard deviation $\sigma_x$.
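
To make that notation concrete, here is the same statement in numpy terms, a tiny sketch (the numbers are made up for illustration):


In [ ]:
# "x ~ N(xT, sigma_x)" just means: draw x at random from a normal
# distribution centered on xT with standard deviation sigma_x.
xT, sigma_x = 5.0, 0.2      # made-up values, just for illustration
x = random.normal(xT, sigma_x)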

A STAN model is a file or string containing blocks of code. At minimum, three blocks are needed: data, parameters, and model. The data block describes what data will be input into the STAN model; you should think of these as fixed values. The parameters block defines all the parameters your model depends on. These values will change throughout the MCMC chain and will be output at the end. The model block describes the probabilistic model that links your data to your parameters.

Here is a simple STAN model to fit the emission line. As in C, we need to declare not only the variables we're using, but also their sizes if they are arrays.


In [ ]:
model_string = '''
data {
  int<lower=1> N;       // number of data points
  vector[N] wave;       // the wavelengths
  vector[N] flux;       // the observed flux values
}

parameters {
  real<lower=-1000, upper=1000> cont;  // continuum level
  real<lower=-100, upper=100> slope;   // continuum slope
  real<lower=0, upper=1000> amp;       // amplitude of the Gaussian
  real<lower=0, upper=10> center;      // center of the line
  real<lower=0, upper=10> width;       // width of the line
}

model {
  vector[N] mod_flux;    // the model flux

  // continuum slope + Gaussian
  mod_flux = amp*exp(-0.5*square(center - wave)/square(width)) +
             cont + slope*wave;
  // Poisson is approximately Normal with sigma = sqrt(counts)
  flux ~ normal(mod_flux, sqrt(flux));
}'''

That's it. The prior on width being strictly positive is handled by the lower=0 bound in its declaration in the parameters block. While we didn't use any indexing in this model (thanks to casting everything in terms of vectors), note that STAN indexes from 1, not 0. You'll also notice I put limits on all the parameters. This is a good idea with STAN, as it will use this information to re-scale each parameter, making the computations more accurate and the initial steps reasonable. The exact values of the limits aren't super important; just make sure the range is larger than the posterior distribution of your parameters.

The next step is to compile this code into a library that pystan will use to sample the posterior. This is all done behind the scenes, but you can set verbose to True if you want to see all the details. This step is the most annoying, as you'll likely make syntax errors in the STAN code (I did while writing this), so it's another level of debugging you'll need to do. Also, the compiling can take some time (lots of optimizations being done). There is a trick for re-using compiled code, but I won't bother here.


In [ ]:
import pystan
sampler = pystan.StanModel(model_code=model_string, verbose=False)
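
As an aside, the trick for re-using compiled code mentioned above usually amounts to pickling the compiled model and re-loading it instead of re-compiling. A minimal sketch (the file name emission_model.pkl is an arbitrary choice):


In [ ]:
# Re-use a compiled model if we have one; otherwise compile and cache it.
import pickle, os
if os.path.exists('emission_model.pkl'):
    with open('emission_model.pkl', 'rb') as f:
        sampler = pickle.load(f)
else:
    sampler = pystan.StanModel(model_code=model_string)
    with open('emission_model.pkl', 'wb') as f:
        pickle.dump(sampler, f)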

We've now got a sampler, so it's time to sample. You do this with the sampling function. You need to give it the data to fit, which you do by constructing a dictionary with each key matching a variable name in the data block. Then tell it how many iterations to run (iter), how many chains (chains), and how many initial iterations to throw out (warmup). Unlike emcee, STAN does not use the parallel chains (or "walkers" in emcee-speak) to improve the sampler, but rather to test when they have converged. We'll see this below.


In [ ]:
idata = dict(N=len(wave), wave=wave, flux=obs_flux)
output = sampler.sampling(data=idata, chains=4, iter=5000, warmup=1000)

The output variable contains lots of good stuff. For starters, if you print it out, it gives you some nice statistics on the samples. For each variable, you get the mean of the chains (best-fit value), the error in the mean (an estimate of the error due to having finite sampling), the standard deviation (the error in the best-fit), various percentiles, the effective number of samples, and a useful statistic called Rhat, or $\hat{R}$. The idea: compute the dispersion of the parameter within each chain and average these over the chains; then compute the mean of the parameter in each chain and look at the dispersion of these means between chains. If the chains have all converged to the same value with the same scatter, the within-chain and between-chain estimates of the variance agree and $\hat{R}$ is close to 1.0. It's a good test for whether your chains have converged.


In [ ]:
print(output)

If the Rhat values are all close to 1.0, things are good. If not, then either your chains have not converged, or the model has a problem. Multi-modal posteriors (i.e., where multiple solutions exist) are a good example. Each chain may get "stuck" in one or the other regions of high local probability and never converge to the same answer. In our case, though, Rhat should be close to 1.
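
If you're curious what that looks like in code, here is a rough sketch of the idea behind $\hat{R}$ using the un-permuted chains (STAN's actual formula includes extra finite-sample corrections, so treat this as illustrative only):


In [ ]:
# A rough Rhat for the first parameter (not STAN's exact formula).
chains = output.extract(permuted=False)  # shape: (iterations, chains, parameters)
par0 = chains[:,:,0]                     # first parameter, all chains
W = mean(var(par0, axis=0))              # average within-chain variance
B = var(mean(par0, axis=0))              # variance of the chain means
print(sqrt((W + B)/W))                   # close to 1.0 when converged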

You could go with the summary statistics above, but most of us like to go a bit further, plotting the traces, covariances, etc. To access the chains, we use the extract() method. In this case, I will also specify permuted=True, which combines the chains and assigns them to a dictionary for easy access. Let's plot out the chains like we did with emcee.


In [ ]:
samples = output.extract(permuted=True)
pars = list(samples.keys())
fig2,axes = subplots(len(pars),1)
for i,par in enumerate(pars):
    axes[i].plot(samples[par], '-')
    axes[i].set_ylabel(par)

Hopefully, everything looks nice and converged. Note that because we threw away the first 1000 iterations (warmup), the chains are already converged by the time we see them. This needn't have been the case: if there were any sign of non-convergence at the beginning, you could simply run for longer. The special variable lp__ is the log-probability of the model plus some constant.
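
One handy use for lp__: the sample where it peaks is the single highest-probability draw, which makes a quick point estimate. A small sketch:


In [ ]:
# The sample with the highest log-probability is a handy point estimate.
best = argmax(samples['lp__'])
for par in pars[:-1]:        # skip lp__ itself
    print(par, samples[par][best])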

Let's plot the triangle plot as we did before.


In [ ]:
import corner
arr = array([samples[par] for par in pars[:-1]])
rmp = corner.corner(arr.T, labels=pars[:-1],
                    truths=[cont_zp,cont_slope,amplitude,center,width])

Something else we can do with the samples is build the best-fit model as well as some representative fits and plot them. The blue line will be the best fit and the grey region will represent a 3-sigma "error snake" based on 100 samples from the chains.


In [ ]:
def Gauss(x, amp, center, width, cont, slope):
    return amp*exp(-0.5*power(x-center,2)/width**2) + cont + \
           slope*x

mamp = median(samples['amp'])
mcont = median(samples['cont'])
mslope = median(samples['slope'])
mcenter = median(samples['center'])
mwidth = median(samples['width'])
xx = linspace(wave.min(), wave.max(), 100)
yy = Gauss(xx, mamp, mcenter, mwidth, mcont, mslope)
ax.plot(xx, yy, '-', color='b')
# take every 10th sample so the 100 draws are spread through the chains
yys = [Gauss(xx, samples['amp'][ii*10], samples['center'][ii*10],
             samples['width'][ii*10], samples['cont'][ii*10],
             samples['slope'][ii*10])
       for ii in range(100)]
sdy = std(yys, axis=0)
ax.fill_between(xx, yy-3*sdy, yy+3*sdy, facecolor='k', alpha=0.4, zorder=10)
fig
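
Finally, if you want to quote uncertainties straight from the chains, percentiles of the samples give credible intervals. A quick sketch for the amplitude:


In [ ]:
# 68% credible interval for the amplitude, straight from the samples.
lo,med,hi = percentile(samples['amp'], [16,50,84])
print('amp = {:.1f} +{:.1f} -{:.1f}'.format(med, hi-med, med-lo))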