Testing that the NUTS checkpointing works as intended.
In [32]:
import pickle
import numpy as np
import nutstrajectory
import matplotlib.pyplot as pp
In [33]:
# Show matplotlib figures inline in the notebook output.
%matplotlib inline
In [115]:
# Re-import nutstrajectory so that edits to the module are picked up
# without restarting the kernel.
# `reload` is a builtin only on Python 2; on Python 3 it lives in importlib.
try:
    from importlib import reload
except ImportError:
    pass  # Python 2: importlib has no reload, the builtin is used instead
reload(nutstrajectory)
Out[115]:
In [5]:
def correlated_normal(theta):
    """
    Unnormalized log-density and gradient of a correlated 2-D Gaussian.

    A toy target for exercising NUTS (a Gaussian can of course be sampled
    directly and far more cheaply). The normalizing constant is omitted.

    Returns the pair (logp, grad) evaluated at theta.
    """
    # Inverse of the covariance [1, 1.98; 1.98, 4], i.e. np.linalg.inv(cov).
    precision = np.asarray([[50.251256, -24.874372],
                            [-24.874372, 12.562814]])
    # Gradient of the log-density: -theta . precision
    neg_precision_theta = -np.dot(theta, precision)
    # Log-density: -0.5 * theta . precision . theta^T, reusing the gradient.
    half_quadratic_form = 0.5 * np.dot(neg_precision_theta, theta.T)
    return half_quadratic_form, neg_precision_theta
In [130]:
# Configuration shared by every nuts6 call in this notebook.
SEED = 12345     # fixed seed so the starting point is the same after a kernel restart
D = 2            # dimension of the target (correlated_normal is 2-D)
M = 40000        # passed to nuts6 as M -- presumably the number of samples; confirm in nutstrajectory
Madapt = 4000    # passed to nuts6 as Madapt -- presumably the number of adaptation steps; confirm
delta = 0.2      # passed to nuts6 as delta -- presumably the target acceptance rate; confirm

# Draw the initial point from a private, seeded generator so that it is
# reproducible without touching the state of the global numpy RNG.
rng = np.random.RandomState(SEED)
theta0 = rng.normal(0, 1, D)

# Mean and covariance of the target; cov is the covariance whose inverse is
# hard-coded as the precision matrix in correlated_normal.
mean = np.zeros(D)
cov = np.asarray([[1, 1.98],
                  [1.98, 4]])
In [131]:
# Remove output from any previous run. -f keeps the cell from erroring when
# the files do not exist yet (e.g. on a first run).
# Note: the save* glob also matches save2* and save3* files.
!rm -f chain.txt save*
In [132]:
# First run, written to chain.txt with checkpoint prefix 'save'.
samples, lnprob, epsilon = nutstrajectory.nuts6(
    correlated_normal, M, Madapt, theta0, delta,
    outFile='chain.txt', pickleFile='save'
)
In [141]:
# Remove output from any previous second run. -f keeps the cell from erroring
# when the files do not exist yet (e.g. on a first run).
!rm -f chain2.txt save2*
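Interrupt the following cell by hand (Kernel → Interrupt) roughly halfway through its run, so that a partial chain and checkpoint are left on disk for the cells below.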
In [142]:
# Second run, written to chain2.txt with checkpoint prefix 'save2'.
# This is the call that is meant to be interrupted by hand partway through.
samples2, lnprob2, epsilon2 = nutstrajectory.nuts6(
    correlated_normal, M, Madapt, theta0, delta,
    outFile='chain2.txt', pickleFile='save2'
)
In [143]:
# Line count of chain.txt, the output file of the first run.
!wc -l chain.txt
In [144]:
# Line count of chain2.txt after the interrupted run.
!wc -l chain2.txt
In [145]:
# Duplicate the interrupted run's save2-lnprob.npy under the save3 prefix.
!cp save2-lnprob.npy save3-lnprob.npy
In [146]:
# Duplicate the interrupted run's save2-samples.npy under the save3 prefix.
!cp save2-samples.npy save3-samples.npy
In [147]:
# Duplicate the interrupted run's save2.pickle under the save3 prefix.
!cp save2.pickle save3.pickle
In [148]:
# Duplicate the partially written chain2.txt as chain3.txt.
!cp chain2.txt chain3.txt
In [149]:
# Call nuts6 again with the same outFile / pickleFile as the interrupted run.
# Presumably nuts6 picks up from the files left on disk -- confirm in nutstrajectory.
samples2, lnprob2, epsilon2 = nutstrajectory.nuts6(
    correlated_normal, M, Madapt, theta0, delta,
    outFile='chain2.txt', pickleFile='save2'
)
In [150]:
# Same call, pointed at chain3.txt / save3, the copies of the interrupted
# run's files. Presumably nuts6 continues from them -- confirm in nutstrajectory.
samples3, lnprob3, epsilon3 = nutstrajectory.nuts6(
    correlated_normal, M, Madapt, theta0, delta,
    outFile='chain3.txt', pickleFile='save3'
)
In [151]:
# Line count of chain2.txt after nuts6 was called on it a second time.
!wc -l chain2.txt
In [152]:
# Line count of chain3.txt after nuts6 was called on the copied files.
!wc -l chain3.txt
In [153]:
# Overlay the marginal histograms of the three sample arrays, one panel per
# coordinate, to compare them by eye.
# `density=True` replaces the `normed=True` keyword, which has been removed
# from matplotlib's hist.
fig, axes = pp.subplots(1, 2)
chains = [(samples, 'r', 'samples'),
          (samples2, 'g', 'samples2'),
          (samples3, 'b', 'samples3')]
for dim, ax in enumerate(axes):
    for chain, color, label in chains:
        ax.hist(chain[:, dim], 20, density=True, histtype='step',
                color=color, label=label)
    ax.set_xlabel('theta[%d]' % dim)
    ax.set_ylabel('density')
axes[0].legend()
Out[153]: