Nested rejection sampling

This example demonstrates how to use nested rejection sampling [1] to sample from the posterior distribution for a logistic model fitted to model-simulated data.

Nested sampling is one of the strangest ways to calculate an integral that you'll ever come across, and it has found widespread application in physics. The idea is based upon repeatedly restricting the prior density to regions of parameter space defined by increasing likelihood thresholds. These repeated restrictions form a sort of Matryoshka doll of spaces, where later surfaces are "nested" within earlier ones. The spaces between successive Matryoshka volumes constitute "shells", whose volumes can themselves be approximated. By summing the likelihood-weighted volumes of these shells, the marginal likelihood can be estimated. It sounds bonkers, but it works. It works especially well for multimodal distributions, where traditional methods of calculating the marginal likelihood fail. As a very useful by-product of nested sampling, posterior samples can be produced by importance sampling.

[1] "Nested Sampling for General Bayesian Computation", John Skilling, Bayesian Analysis (2006) https://projecteuclid.org/download/pdf_1/euclid.ba/1340370944.

First create fake data.


In [1]:
import pints
import pints.toy as toy
import numpy as np
import matplotlib.pyplot as plt

# Load a forward model
model = toy.LogisticModel()

# Create some toy data
r = 0.015
k = 500
real_parameters = [r, k]
times = np.linspace(0, 1000, 100)
signal_values = model.simulate(real_parameters, times)

# Add independent Gaussian noise
sigma = 10
observed_values = signal_values + pints.noise.independent(sigma, signal_values.shape)

# Plot
plt.plot(times, signal_values, label='signal')
plt.plot(times, observed_values, label='observed')
plt.xlabel('Time')
plt.ylabel('Values')
plt.legend()
plt.show()


Create the nested sampler that will be used to sample from the posterior.


In [2]:
# Create an object with links to the model and time series
problem = pints.SingleOutputProblem(model, times, observed_values)

# Create a log-likelihood function (adds an extra parameter!)
log_likelihood = pints.GaussianLogLikelihood(problem)

# Create a uniform prior over both the parameters and the new noise variable
log_prior = pints.UniformLogPrior(
    [0.01, 400, sigma * 0.5],
    [0.02, 600, sigma * 1.5])

# Create a nested rejection sampler
sampler = pints.NestedController(log_likelihood, log_prior, method=pints.NestedRejectionSampler)

# Set number of iterations
sampler.set_iterations(3000)

# Set the number of posterior samples to generate
sampler.set_n_posterior_samples(300)

Run the sampler!


In [3]:
samples = sampler.run()
print('Done!')


Running Nested rejection sampler
Number of active points: 400
Total number of iterations: 3000
Total number of posterior samples: 300
Iter. Eval. Time m:s Delta_log(z) Acceptance rate
0     1       0:00.0 -inf          1             
0     2       0:00.0 -inf          1             
0     21      0:00.0 -inf          1             
0     41      0:00.0 -inf          1             
0     61      0:00.0 -inf          1             
0     81      0:00.0 -inf          1             
0     101     0:00.0 -inf          1             
0     121     0:00.0 -inf          1             
0     141     0:00.0 -inf          1             
0     161     0:00.0 -inf          1             
0     181     0:00.0 -inf          1             
0     201     0:00.0 -inf          1             
0     221     0:00.0 -inf          1             
0     241     0:00.0 -inf          1             
0     261     0:00.0 -inf          1             
0     281     0:00.0 -inf          1             
0     301     0:00.0 -inf          1             
0     321     0:00.0 -inf          1             
0     341     0:00.0 -inf          1             
0     361     0:00.0 -inf          1             
0     381     0:00.0 -inf          1             
1     401     0:00.0 -inf          1             
20    421     0:00.0 -12522.86197  0.952380952381
40    443     0:00.1 -8692.935714  0.930232558   
60    464     0:00.1 -6749.535314  0.9375        
80    487     0:00.1 -5960.760661  0.91954023    
100   508     0:00.1 -5169.161112  0.925925926   
120   538     0:00.1 -4484.11497   0.869565217   
140   565     0:00.1 -4045.815411  0.848484848   
160   595     0:00.1 -3457.238432  0.820512821   
180   624     0:00.1 -3156.197327  0.803571429   
200   657     0:00.1 -2915.104877  0.778210117   
220   702     0:00.1 -2656.953402  0.728476821   
240   741     0:00.1 -2518.454139  0.703812317   
260   771     0:00.1 -2328.10491   0.700808625   
280   820     0:00.2 -2195.132374  0.666666667   
300   865     0:00.2 -2082.977906  0.64516129    
320   915     0:00.2 -1942.591729  0.621359223301
340   963     0:00.2 -1800.001913  0.603907638   
360   998     0:00.2 -1689.27647   0.602006689   
380   1043    0:00.2 -1604.018798  0.590979782   
400   1114    0:00.2 -1492.391812  0.56022409    
420   1165    0:00.3 -1399.565096  0.549019608   
440   1236    0:00.3 -1311.441787  0.526315789   
460   1281    0:00.3 -1229.742485  0.522133938706
480   1364    0:00.3 -1145.722625  0.497925311   
500   1424    0:00.3 -1104.763052  0.48828125    
520   1501    0:00.3 -1044.825747  0.47229791099 
540   1563    0:00.4 -979.3983565  0.464316423   
560   1642    0:00.4 -944.601317   0.450885668277
580   1715    0:00.4 -903.856668   0.441064639   
600   1793    0:00.4 -869.5654442  0.430725054   
620   1873    0:00.4 -834.1828438  0.420909708   
640   1962    0:00.5 -778.2974514  0.409731114   
660   2054    0:00.5 -730.1749179  0.399032648   
680   2144    0:00.5 -700.1671269  0.389908257   
700   2222    0:00.5 -672.4710088  0.384193194292
720   2321    0:00.5 -636.9215653  0.374804789   
740   2431    0:00.6 -609.8814255  0.364352536   
760   2561    0:00.6 -578.0223569  0.351689033   
780   2691    0:00.6 -554.9051664  0.34046268    
800   2828    0:00.6 -535.5797065  0.329489291598
820   2987    0:00.6 -513.9365002  0.316969463   
840   3098    0:00.7 -492.4884241  0.311341735   
860   3241    0:00.7 -472.3655125  0.30271031327 
880   3425    0:00.7 -449.5596325  0.290909091   
900   3571    0:00.7 -428.663329   0.283822138   
920   3772    0:00.8 -415.5802956  0.272835113   
940   3941    0:00.8 -398.5059426  0.265461734   
960   4211    0:00.8 -377.8020555  0.251902388   
980   4425    0:00.8 -362.37723    0.243478261   
1000  4581    0:00.8 -344.3738101  0.23917723    
1020  4816    0:00.9 -325.2670325  0.230978261   
1040  5125    0:00.9 -305.8137077  0.22010582    
1060  5514    0:01.0 -292.7783879  0.207274149   
1080  5764    0:01.0 -273.4904059  0.201342282   
1100  6169    0:01.1 -263.5950722  0.190674294   
1120  6466    0:01.1 -253.5282184  0.184635674   
1140  6719    0:01.1 -243.5706387  0.180408292   
1160  7016    0:01.2 -232.2379377  0.175332527   
1180  7434    0:01.2 -219.0717086  0.167756611   
1200  7856    0:01.2 -212.3579747  0.160944206   
1220  8122    0:01.3 -204.9107414  0.157990158   
1240  8709    0:01.3 -193.9996345  0.149235768   
1260  9091    0:01.4 -187.8771439  0.144977563   
1280  9514    0:01.4 -179.2997265  0.140443274   
1300  9916    0:01.5 -171.49347    0.136612022   
1320  10412   0:01.5 -165.8781649  0.13184179    
1340  11106   0:01.6 -158.6554766  0.12516346    
1360  11460   0:01.6 -151.233046   0.122965641953
1380  11971   0:01.7 -146.438354   0.119263676   
1400  12471   0:01.7 -139.8406842  0.11598044901 
1420  12908   0:01.8 -133.7776935  0.113527343   
1440  13553   0:01.8 -128.3981969  0.109480727   
1460  14105   0:01.9 -123.322206   0.106530463   
1480  14703   0:02.0 -118.0687888  0.103474795   
1500  15316   0:02.0 -111.0895076  0.100563154   
1520  15934   0:02.1 -106.9406826  0.0978498777  
1540  16998   0:02.2 -103.1158405  0.0927822629  
1560  17880   0:02.3 -99.82342797  0.0892448513  
1580  18680   0:02.5 -96.69939537  0.0864332604  
1600  19496   0:02.5 -92.57334525  0.0837871806  
1620  20610   0:02.7 -89.1746835   0.0801583375  
1640  21357   0:02.7 -85.23899649  0.0782554755  
1660  22180   0:02.8 -81.86833143  0.0762167126  
1680  23093   0:02.9 -77.85846454  0.0740316397  
1700  23974   0:03.0 -74.63195772  0.0721133452  
1720  25398   0:03.1 -71.66681895  0.0688055044  
1740  26409   0:03.2 -68.85347155  0.0668999193  
1760  27720   0:03.3 -66.1277549   0.0644216691  
1780  29078   0:03.5 -63.49692586  0.0620684846  
1800  30495   0:03.6 -61.13172875  0.0598105998  
1820  32114   0:03.8 -58.6684506   0.0573879044  
1840  33613   0:03.9 -56.25201326  0.055399994   
1860  35643   0:04.1 -53.96645456  0.0527764379  
1880  37137   0:04.2 -52.16608835  0.0511745652  
1900  39010   0:04.4 -50.39367749  0.0492100492  
1920  40428   0:04.6 -48.39819409  0.0479664235  
1940  42657   0:04.7 -46.6805846   0.0459095534  
1960  45251   0:05.0 -45.01217513  0.0437002519  
1980  47408   0:05.2 -43.4935904   0.0421204901  
2000  49626   0:05.3 -42.11083007  0.0406289359  
2020  52444   0:05.6 -40.80980339  0.0388133118  
2040  54435   0:05.8 -39.71935239  0.037753308   
2060  59038   0:06.2 -38.60043019  0.0351308026  
2080  62527   0:06.5 -37.2758829   0.0334798075  
2100  65634   0:06.7 -35.94925229  0.0321918018  
2120  67926   0:06.9 -34.78982829  0.0313953144  
2140  70786   0:07.2 -33.46074232  0.0304037735  
2160  75757   0:07.6 -32.14866578  0.0286635614  
2180  79918   0:08.0 -31.17933821  0.0274151764  
2200  83771   0:08.3 -30.29859336  0.0263880726  
2220  87827   0:08.7 -29.43966074  0.0253926133  
2240  91863   0:09.0 -28.5957626   0.0244907777  
2260  95359   0:09.4 -27.7039827   0.0237997452  
2280  100745   0:09.9 -26.88755899  0.0227216104  
2300  106207   0:10.4 -26.02819286  0.0217376922  
2320  109816   0:10.7 -25.25182353  0.0212034803  
2340  115300   0:11.2 -24.47739589  0.0203655352  
2360  121796   0:11.8 -23.72723113  0.0194405087  
2380  128051   0:12.4 -23.73528819  0.018644585628
2400  134075   0:12.9 -23.17965258  0.0179539929  
2420  139928   0:13.5 -22.6736489   0.017344189   
2440  149469   0:14.3 -22.21379216  0.016368259   
2460  158262   0:15.1 -21.66682745  0.0155832309  
2480  166821   0:15.9 -21.08046261  0.0149019655  
2500  177399   0:16.7 -20.72258566  0.0141243736  
2520  185783   0:17.4 -20.20283395  0.0135934794  
2540  198541   0:18.5 -19.68490287  0.012819154   
2560  211083   0:19.9 -19.19529441  0.0121509567  
2580  220189   0:21.2 -18.68109056  0.011738531   
2600  229709   0:22.1 -18.22852536  0.0113384124  
2620  238697   0:22.9 -17.78068309  0.0109946831  
2640  252493   0:24.2 -17.31671312  0.0104723257  
2660  268322   0:25.7 -16.87040837  0.0099282627  
2680  283608   0:27.1 -16.4664702   0.00946300952 
2700  294338   0:28.3 -16.04131444  0.00918561057 
2720  312948   0:30.5 -15.57954436  0.00870266327 
2740  325367   0:32.7 -15.15121394  0.00843162537 
2760  339124   0:35.9 -14.7704518   0.00814822687 
2780  352438   0:38.4 -14.39155097  0.00789687477 
2800  370071   0:42.8 -14.02172634  0.00757430256 
2820  403182   0:52.8 -13.6535636   0.00700130592 
2840  419444   0:57.0 -13.56274993  0.00677733126 
2860  433167   1:00.5 -13.22855669  0.00660863698 
2880  450780   1:05.1 -12.90779864  0.00639460012 
2900  475464   1:11.2 -12.57087746  0.00610444066 
2920  497775   1:16.3 -12.24281505  0.00587082181 
2940  523789   1:22.0 -11.90509479  0.00561723689 
2960  553972   1:24.9 -11.587114    0.00534709125 
2980  584939   1:27.3 -11.29146672  0.00509803452 
3000  608174   1:29.2 -10.9867037   0.00493604531 
Done!

Plot posterior samples versus true parameter values (dashed lines)


In [4]:
# Plot output
import pints.plot

pints.plot.histogram([samples], ref_parameters=[r, k, sigma])

plt.show()



In [5]:
pints.plot.pairwise(samples, kde=True)

plt.show()


Plot posterior predictive simulations versus the observed data


In [6]:
pints.plot.series(samples[:100], problem)
plt.show()


Marginal likelihood estimate

Nested sampling calculates the denominator of Bayes' rule (the marginal likelihood $Z$) by applying the trapezium rule to the integral,

$$Z = \int_{0}^{1} \mathcal{L}(X) dX,$$

where $X$ is the fraction of prior probability mass for which the likelihood exceeds $\mathcal{L}(X)$.
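
In discrete form, this becomes a weighted sum over the sequence of discarded points, with trapezium-rule weights over the shrinking prior volumes $X_i$ (see [1]),

$$Z \approx \sum_{i} w_i \mathcal{L}_i, \qquad w_i = \tfrac{1}{2}\left(X_{i-1} - X_{i+1}\right),$$

where $\mathcal{L}_i$ is the likelihood of the point discarded at iteration $i$.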


In [7]:
print('marginal log-likelihood = ' + str(sampler.marginal_log_likelihood())
      + ' ± ' + str(sampler.marginal_log_likelihood_standard_deviation()))


marginal log-likelihood = -368.1116589772709 ± 0.08064572183322265

With PINTS we can access the segments of the discretised integral, meaning we can plot the function being integrated.


In [8]:
v_log_likelihood = sampler.log_likelihood_vector()
v_log_likelihood = v_log_likelihood[:-sampler._sampler.n_active_points()]
X = sampler.prior_space()
X = X[:-1]
plt.plot(X, v_log_likelihood)
plt.xlabel('prior volume X enclosed by the likelihood contour')
plt.ylabel('log likelihood')
plt.show()


Examine active and inactive points at end of sampling run

At each step of the nested sampling algorithm, the point with the lowest likelihood is discarded (and inactivated) and a new active point is drawn from the prior, with the restriction that its likelihood exceeds that of the discarded point. The likelihood of the inactivated point essentially defines the height of a segment of the discretised integral for $Z$. Its width is approximately given (via the trapezium rule) by $w_i = \tfrac{1}{2}(X_{i-1}-X_{i+1})$, where $X_i = \text{exp}(-i / N)$, $N$ is the number of active points and $i$ is the iteration.
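
For example, with the 400 active points and 3000 iterations used in the run above, these widths can be computed as follows (an illustrative snippet, not the PINTS internals).

import numpy as np

n_active, n_iters = 400, 3000
i = np.arange(n_iters + 2)
X = np.exp(-i / n_active)      # X_0 = 1, X_1, ..., X_{n_iters + 1}
w = 0.5 * (X[:-2] - X[2:])     # w_i = (X_{i-1} - X_{i+1}) / 2 for i = 1, ..., n_iters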

PINTS keeps track of active and inactive points at the end of the nested sampling run. The active points (orange) are concentrated in a region of high likelihood, and their likelihoods always exceed those of the discarded inactive points (blue).


In [9]:
m_active = sampler.active_points()
m_inactive = sampler.inactive_points()

f, axarr = plt.subplots(1, 3, figsize=(15, 6))
axarr[0].scatter(m_inactive[:, 0], m_inactive[:, 1])
axarr[0].scatter(m_active[:, 0], m_active[:, 1], alpha=0.1)
axarr[0].set_xlim([0.008, 0.022])
axarr[0].set_xlabel('r')
axarr[0].set_ylabel('k')
axarr[1].scatter(m_inactive[:, 0], m_inactive[:, 2])
axarr[1].scatter(m_active[:, 0], m_active[:, 2], alpha=0.1)
axarr[1].set_xlim([0.008, 0.022])
axarr[1].set_xlabel('r')
axarr[1].set_ylabel('sigma')
axarr[2].scatter(m_inactive[:, 1], m_inactive[:, 2])
axarr[2].scatter(m_active[:, 1], m_active[:, 2], alpha=0.1)
axarr[2].set_xlabel('k')
axarr[2].set_ylabel('sigma')
plt.show()


Generate another set of posterior samples from the recent run

In nested sampling, we can apply importance sampling to the inactivated points to generate posterior samples. In this case, the weight of each inactive point is given by $w_i \mathcal{L}_i$, where $\mathcal{L}_i$ is its likelihood. Since we use importance sampling, we can always generate an alternative set of posterior samples by re-applying this method.
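
As a rough illustration of what this resampling amounts to (a sketch only, not PINTS' internal implementation; the function name and arguments are hypothetical), one can resample the discarded points with probabilities proportional to $w_i \mathcal{L}_i$.

import numpy as np

def resample_posterior(points, log_widths, log_likelihoods, n_samples, seed=None):
    # Importance resampling: keep each discarded point with probability proportional to w_i * L_i
    rng = np.random.default_rng(seed)
    log_weights = log_widths + log_likelihoods
    weights = np.exp(log_weights - np.logaddexp.reduce(log_weights))
    weights /= weights.sum()   # guard against floating-point round-off
    idx = rng.choice(len(points), size=n_samples, p=weights)
    return points[idx]

Because each call draws a fresh weighted resample, repeated calls (like `sample_from_posterior` below) can return slightly different sets of posterior samples.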


In [10]:
samples_new = sampler.sample_from_posterior(1000)

pints.plot.pairwise(samples_new, kde=True)

plt.show()