BayesianMarkovStateModel

This example demonstrates the class BayesianMarkovStateModel, which uses Metropolis Markov chain Monte Carlo (MCMC) to sample
over the posterior distribution of transition matrices, given the observed transitions in your dataset. This can be useful
for evaluating the uncertainty due to sampling in your dataset.
In [ ]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from mdtraj.utils import timing
from msmbuilder.example_datasets import load_doublewell
from msmbuilder.cluster import NDGrid
from msmbuilder.msm import BayesianMarkovStateModel, MarkovStateModel
In [ ]:
# Load the example dataset with a fixed random_state and keep its trajectories.
dataset = load_doublewell(random_state=0)
trjs = dataset['trajectories']

# Pool all trajectories into a single array and histogram it on a log scale.
pooled = np.concatenate(trjs)
plt.hist(pooled, bins=50, log=True)
plt.ylabel('Frequency')
plt.show()
In [ ]:
# Discretize the trajectories onto a grid with 10 bins per feature.
clusterer = NDGrid(n_bins_per_feature=10)
mle_msm = MarkovStateModel(lag_time=100)
b_msm = BayesianMarkovStateModel(lag_time=100, n_samples=10000, n_steps=1000)

states = clusterer.fit_transform(trjs)

# Only the Bayesian (MCMC) fit is wrapped in the timer, matching its label.
with timing('running mcmc'):
    b_msm.fit(states)

# Fit the plain MarkovStateModel on the same states; later cells draw it as
# the reference line next to the MCMC samples.
mle_msm.fit(states)
In [ ]:
# One stacked panel per transition-matrix entry: the sampled values as a
# trace, with the corresponding mle_msm entry as a black horizontal line.
panels = [((0, 0), 't_00'), ((2, 3), 't_23')]
for row, ((i, j), label) in enumerate(panels, start=1):
    plt.subplot(2, 1, row)
    plt.plot(b_msm.all_transmats_[:, i, j])
    plt.axhline(mle_msm.transmat_[i, j], c='k')
    plt.ylabel(label)

# The x-label is set once, on the last (bottom) panel.
plt.xlabel('MCMC Iteration')
plt.show()
In [ ]:
# First column of the sampled timescales versus the first mle_msm timescale.
sampled_timescale = b_msm.all_timescales_[:, 0]
reference_timescale = mle_msm.timescales_[0]

plt.plot(sampled_timescale, label='MCMC')
plt.axhline(reference_timescale, c='k', label='MLE')
plt.xlabel('MCMC iteration')
plt.ylabel('Longest timescale')
plt.legend(loc='best')
plt.show()
In [ ]:
# Repeat the fit on a finer grid: 50 bins per feature.
clusterer = NDGrid(n_bins_per_feature=50)
mle_msm = MarkovStateModel(lag_time=100)
b_msm = BayesianMarkovStateModel(lag_time=100, n_samples=1000, n_steps=100000)

states = clusterer.fit_transform(trjs)

# Only the Bayesian (MCMC) fit is wrapped in the timer, matching its label.
with timing('running mcmc (50 states)'):
    b_msm.fit(states)

# Fit the plain MarkovStateModel on the same states; later cells draw it as
# the reference line next to the MCMC samples.
mle_msm.fit(states)
In [ ]:
# First column of the sampled timescales versus the first mle_msm timescale.
plt.plot(b_msm.all_timescales_[:, 0], label='MCMC')
plt.axhline(mle_msm.timescales_[0], c='k', label='MLE')
plt.legend(loc='best')
plt.ylabel('Longest timescale')
plt.xlabel('MCMC iteration')
# Close the figure explicitly so a later plotting cell does not draw onto
# these axes and the cell does not echo the repr of the last call.
plt.show()
In [ ]:
# Sampled values of transition-matrix entry [0, 0] versus the mle_msm entry.
plt.plot(b_msm.all_transmats_[:, 0, 0], label='MCMC')
plt.axhline(mle_msm.transmat_[0, 0], c='k', label='MLE')
plt.legend(loc='best')
plt.ylabel('t_00')
plt.xlabel('MCMC iteration')
# Show the figure explicitly instead of leaving the last call's repr as the
# cell output.
plt.show()
In [ ]: