In this notebook I explore how to save intermediate results of an MCMC run to a backup file, such as an HDF5 file.
This scenario is described in emcee Advanced Patterns as "Incrementally saving progress":
It is often useful to incrementally save the state of the chain to a file. This makes it easier to monitor the chain’s progress and it makes things a little less disastrous if your code/computer crashes somewhere in the middle of an expensive MCMC run. If you just want to append the walker positions to the end of a file, you could do something like:
f = open("chain.dat", "w")
f.close()

for result in sampler.sample(pos0, iterations=500, storechain=False):
    position = result[0]
    f = open("chain.dat", "a")
    for k in range(position.shape[0]):
        # position[k] is an array of floats, so convert before joining:
        f.write("{0:4d} {1:s}\n".format(k, " ".join(map(str, position[k]))))
    f.close()
The --samples flag to star.py is passed to this main part of the code:
sampler = StateSampler(lnprob, p0, cov, query_lnprob=query_lnprob,
                       acceptfn=acceptfn, rejectfn=rejectfn, debug=True,
                       outdir=Starfish.routdir)
p, lnprob, state = sampler.run_mcmc(p0, N=args.samples)
One key insight is that StateSampler is a subclass of emcee.Sampler...
In [1]:
from emcee.sampler import Sampler
In [2]:
def bogus_lnprob(p):
    return 1.0
In [3]:
samp = Sampler(3, bogus_lnprob)
In [4]:
samp.run_mcmc( # Hit shift-tab... also peek at samp.sample(), etc...
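Shift-tab introspection is interactive; the same information can also be pulled out programmatically with Python's inspect module (a minimal sketch — the exact signatures printed depend on the installed emcee version):
import inspect

# Print the call signatures of the sampling entry points:
print(inspect.signature(samp.run_mcmc))
print(inspect.signature(samp.sample))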
Based on reading the documentation and some experimentation, I found no keyword argument (kwarg) that could be passed to achieve the desired behavior effortlessly. But it did show a path forward: modify the StateSampler.sample method.
I arrived at a fair solution: save the chain to a numpy binary file every 100 samples.
vals.tofile('chain_backup.npy')
This strategy has the disadvantage that it rewrites (rather than appends) the entire chain, which can grow quite large. But since the write only happens every 100 samples, the performance hit should not be too noticeable.
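An append-friendly alternative (just a sketch, not what I implemented here) would be an HDF5 file with a resizable dataset, so each checkpoint writes only the new rows. The append_samples helper below is hypothetical:
import h5py
import numpy as np

ndim = 6  # number of model parameters

# One-time setup: a chunked dataset that can grow along the sample axis.
with h5py.File('chain_backup.hdf5', 'w') as f:
    f.create_dataset('chain', shape=(0, ndim), maxshape=(None, ndim),
                     chunks=True)

def append_samples(new_rows):
    # Append an (n, ndim) block of samples without rewriting the file.
    with h5py.File('chain_backup.hdf5', 'a') as f:
        dset = f['chain']
        n0 = dset.shape[0]
        dset.resize(n0 + new_rows.shape[0], axis=0)
        dset[n0:] = new_rows

# At each checkpoint, e.g.:
append_samples(np.zeros((100, ndim)))  # stand-in for the new samples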
Specifically I modified this section of Starfish/samplers.py:
if storechain and i % thin == 0:
    ind = i0 + int(i / thin)
    self._chain[ind, :] = p
    self._lnprob[ind] = lnprob0

# Save every 100 samples (hardcoded!):
if (i % 100 == 0) and (i > 100):
    self._chain.tofile('chain_backup.npy')

# Heavy duty iterator action going on right here...
yield p, lnprob0, self.random_state
I hardcoded 100 as the incremental spacing, but this could easily be made into a command-line argument and a keyword argument with a default, as sketched below.
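For example (hypothetical names — neither the --incremental_save flag nor the incremental_save kwarg exists in Starfish), the wiring could look like:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--incremental_save", type=int, default=100,
                    help="write the chain to disk every N samples")
args = parser.parse_args()

# Inside StateSampler.sample(), the hardcoded 100 would become a kwarg
# threaded through from the command line:
#     def sample(self, ..., incremental_save=100):
#         ...
#         if (i % incremental_save == 0) and (i > incremental_save):
#             self._chain.tofile('chain_backup.npy')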
In [7]:
import numpy as np
In [8]:
! ls
In [9]:
chain = np.fromfile('chain_backup.npy')
In [10]:
chain.shape
Out[10]:
That's slightly annoying... tofile discards the array shape, so the chain comes back as a flattened 1-D array. We know there are 6 parameters, so we can recover the shape:
In [11]:
n_samples, dims = (chain.shape[0]//6, 6)
In [12]:
flatchain = chain.reshape((n_samples, dims))
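As an aside, the flattening could be avoided at the source: np.save writes a .npy header that records shape and dtype (whereas ndarray.tofile dumps raw bytes), so np.load hands back the 2-D chain directly. A sketch, using flatchain as a stand-in for the sampler's internal chain:
# Hypothetically, in the sampler, instead of self._chain.tofile(...):
np.save('chain_backup.npy', flatchain)  # header records shape and dtype

# Reading it back then needs no manual reshape:
restored = np.load('chain_backup.npy')
restored.shape  # -> (n_samples, 6)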
In [20]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
%config InlineBackend.figure_format = 'svg' #could also do retina...
In [18]:
plt.plot(flatchain[:,0])
plt.ylabel(r'$T_{\mathrm{eff}}$'); plt.xlabel('Sample');
We sampled 310 iterations. The backup does not capture the final 10 samples because the chain is only written to disk every 100 iterations.
In [19]:
plt.plot(flatchain[0:300,0])
plt.ylabel(r'$T_{\mathrm{eff}}$'); plt.xlabel('Sample');
It works!