In [ ]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from msmbuilder.example_datasets import QuadWell, quadwell_eigs
from msmbuilder.cluster import NDGrid
from msmbuilder.msm import MarkovStateModel
from sklearn.pipeline import Pipeline

In [ ]:
# Simulated quad-well trajectories; the fixed random_state makes every run see the same data.
dataset = QuadWell(random_state=0).get()

# Reference eigenvalues: first element of quadwell_eigs' return value.
# NOTE(review): 200 is presumably the discretization size used for the exact solution — confirm.
true_eigenvalues = quadwell_eigs(200)[0]

# Implied timescales t_i = -1 / ln(lambda_i). The leading eigenvalue is skipped
# (presumably the stationary one, lambda_0 = 1, whose timescale would be infinite).
true_timescales = -1 / np.log(true_eigenvalues[1:])

print(QuadWell.description())

In [ ]:
def msm_timescales(trajectories, n_states, n_timescales=4, grid_min=-1.2, grid_max=1.2):
    """Fit a grid-discretized Markov state model and return its implied timescales.

    Each trajectory is binned onto a uniform ``NDGrid`` spanning
    ``[grid_min, grid_max]`` with ``n_states`` bins per feature. A
    ``MarkovStateModel`` is then fit to the resulting state sequences.

    Parameters
    ----------
    trajectories : list of array-like
        Input trajectories, passed unchanged to ``Pipeline.fit``.
        NOTE(review): presumably 2-D ``(n_frames, n_features)`` arrays — confirm.
    n_states : int
        Number of grid bins per feature.
    n_timescales : int, default 4
        Number of implied timescales the MSM computes.
    grid_min, grid_max : float, default -1.2 and 1.2
        Lower/upper bounds of the grid. The defaults reproduce the range this
        notebook originally hard-coded for the quad-well data.

    Returns
    -------
    numpy.ndarray
        The fitted MSM's ``timescales_`` attribute.
    """
    pipeline = Pipeline([
        # Bin count goes straight to the constructor instead of a later set_params call.
        ('grid', NDGrid(n_bins_per_feature=n_states, min=grid_min, max=grid_max)),
        # NOTE(review): 'transpose' presumably enforces reversibility by symmetrizing
        # the count matrix — see the MarkovStateModel docs.
        ('msm', MarkovStateModel(n_timescales=n_timescales,
                                 reversible_type='transpose', verbose=False)),
    ])
    pipeline.fit(trajectories)
    return pipeline.named_steps['msm'].timescales_

# Grid resolutions to sweep; each value is the number of bins, i.e. MSM states.
n_states = [5, 10, 50, 100]

# Stack results into a 2-D array: one row per resolution, one column per timescale.
ts = np.array([
    msm_timescales(dataset.trajectories, n_bins)
    for n_bins in n_states
])

In [ ]:
# Compare MSM-estimated implied timescales (crosses, solid) with the exact reference
# values (dashed horizontal lines) for the first three timescales, one colour each.
fig, ax = plt.subplots()
for i, color in enumerate(['b', 'r', 'm']):
    ax.plot(n_states, ts[:, i], c=color, marker='x',
            label='MSM timescale {}'.format(i + 1))
    ax.axhline(true_timescales[i], ls='--', c=color, lw=2,
               label='Exact timescale {}'.format(i + 1))

ax.set_xlabel('Number of states')
ax.set_ylabel('Timescale (steps)')
ax.set_title('MSM implied timescales vs. number of grid states')
# Without a legend the solid (estimated) and dashed (exact) lines are indistinguishable.
ax.legend()
plt.show()