Minimum Variance Cluster Analysis

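We use Minimum Variance Cluster Analysis (MVCA) to coarse-grain a Markov state model of the quad-well dataset. MVCA is agglomerative clustering with a minimum-variance criterion and the Jensen-Shannon divergence as the distance between states.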


In [ ]:
from msmbuilder.example_datasets import QuadWell
from msmbuilder.msm import MarkovStateModel
from msmbuilder.lumping import MVCA
import numpy as np
import scipy.cluster.hierarchy
import matplotlib.pyplot as plt
% matplotlib inline

Get the dataset


In [ ]:
# Sample the quad-well data with a fixed seed so every run sees the same trajectories.
quad_well = QuadWell(random_state=998)
q = quad_well.get()
ds = q['trajectories']

Define a regular spatial clusterer


In [ ]:
def regular_spatial_clustering(ds, n_bins=20, halfwidth=np.pi):
    """Assign every frame of every 1-D trajectory to one of `n_bins` equal-width bins.

    The bins evenly tile the interval [-halfwidth, halfwidth]. A frame lying
    exactly at +halfwidth is put in the last bin (label ``n_bins - 1``) instead
    of a spurious extra bin ``n_bins``. Frames outside the interval are clipped
    to the nearest edge bin, so every label lies in ``[0, n_bins)``.

    Parameters
    ----------
    ds : list of array-like
        Trajectories. Each must hold exactly one coordinate per frame, i.e.
        have shape ``(n_frames,)`` or ``(n_frames, 1)``.
    n_bins : int
        Number of bins.
    halfwidth : float
        Half the width of the binned interval, which is centred on zero.

    Returns
    -------
    list of np.ndarray
        One integer label array of shape ``(n_frames,)`` per trajectory.
    """
    width = 2 * halfwidth
    labeled = []
    for traj in ds:
        x = np.asarray(traj, dtype=float)
        # One coordinate per frame: (n,) and (n, 1) both flatten to (n,);
        # wider inputs raise, as the old per-frame int() conversion did.
        x = x.reshape(len(x))
        bins = np.floor(n_bins * (x + halfwidth) / width).astype(int)
        # Without clipping, x == +halfwidth would land in bin n_bins.
        labeled.append(np.clip(bins, 0, n_bins - 1))
    return labeled

In [ ]:
# Bin over the whole observed range: the half-width is the largest |x|
# reached in any trajectory.
halfwidth = np.max([np.abs(traj).max() for traj in ds])
assignments = regular_spatial_clustering(ds, halfwidth=halfwidth)
msm_mdl = MarkovStateModel()
msm_mdl.fit(assignments)

Plot the MSM free energy of each bin against the analytic quad-well potential


In [ ]:
def get_centers(n_bins=20, halfwidth=np.pi):
    """Return the midpoints of `n_bins` equal-width bins tiling [-halfwidth, halfwidth].

    With ``step = 2*halfwidth / n_bins``, bin ``k`` spans
    ``[-halfwidth + k*step, -halfwidth + (k+1)*step)``. The returned list is
    ordered by ``k``.
    """
    step = 2 * halfwidth / n_bins
    return [(k + 1) * step - step / 2. - halfwidth for k in range(n_bins)]

In [ ]:
# populations_ is indexed by the MSM's internal state, which need not equal the
# bin label: empty or trimmed bins are renumbered via mapping_. Order the bin
# centres the same way so that ccs[k] and nrgs[k] describe the same state k.
# (Bin labels are assumed to lie in [0, 20), the default get_centers range.)
bin_centers = get_centers(halfwidth=halfwidth)
state_labels = sorted(msm_mdl.mapping_, key=msm_mdl.mapping_.get)
ccs = [bin_centers[label] for label in state_labels]
# F = -c * ln(p) with c = 0.6 (presumably kT in kcal/mol near 300 K -- confirm).
nrgs = [-0.6 * np.log(p) for p in msm_mdl.populations_]

markerline, stemlines, baseline = plt.stem(ccs, nrgs, 'deepskyblue', bottom=-1)
# Newer Matplotlib returns the stems as a single LineCollection rather than a
# list of Line2D, so iterating over them fails; setp handles both forms.
plt.setp(stemlines, linewidth=8)


def potential(x):
    """Analytic quad-well potential, drawn for comparison with the MSM energies."""
    return 4 * (x ** 8 + 0.8 * np.exp(-80 * x ** 2)
                + 0.2 * np.exp(-80 * (x - 0.5) ** 2)
                + 0.5 * np.exp(-40 * (x + 0.5) ** 2))


exes = np.linspace(-np.pi, np.pi, 1000)
whys = potential(exes)
plt.plot(exes, whys, linewidth=2, color='k')

plt.xlim([-halfwidth, halfwidth])
plt.ylim([0, 4])

Build an MVCA model without choosing a number of macrostates, so that we get the full linkage information


In [ ]:
# n_macrostates=None leaves the hierarchy uncut; get_linkage=True keeps the
# agglomeration record that the dendrogram and elbow plots read.
mvca = MVCA.from_msm(
    msm_mdl,
    get_linkage=True,
    n_macrostates=None,
)

`mvca.linkage` is a SciPy linkage matrix, so it can be plotted directly as a dendrogram


In [ ]:
# Links joined at distances >= color_threshold are drawn in the default
# colour; only clusters formed below 1.1 get their own colours.
scipy.cluster.hierarchy.dendrogram(mvca.linkage,
                                   color_threshold=1.1,
                                   no_labels=True)
plt.title('MVCA agglomeration of the MSM microstates')
plt.ylabel('Merge distance')
plt.show()

Use `mvca.elbow_data` to see how the objective function changes with each agglomeration step


In [ ]:
# Point k is drawn at x = k using the digits of k as its marker, so the plot
# labels itself. Iterating over elbow_data (instead of a hard-coded 19, i.e.
# 20 bins - 1) keeps this correct for any number of microstates.
for k, value in enumerate(mvca.elbow_data, start=1):
    # Two-digit markers get twice the area so the glyphs look about the same size.
    plt.scatter([k], value, color='k', marker=r'$%s$' % k,
                s=60 * (np.floor(k / 10) + 1))
plt.xlabel('Number of macrostates')
plt.xticks([])  # the markers themselves carry the x values
plt.show()

Plot some macrostate models


In [ ]:
# Stem colour for each macrostate, indexed by macrostate label (8 available).
color_list = [
    'deepskyblue',
    'hotpink',
    'turquoise',
    'indigo',
    'gold',
    'olivedrab',
    'orangered',
    'whitesmoke',
]

In [ ]:
def plot_macrostates(n_macrostates=4):
    """Lump the microstate MSM into `n_macrostates` with MVCA and plot the result.

    Draws one stem per microstate at ``ccs[state]`` with height
    ``nrgs[state]``, coloured by the macrostate MVCA assigns it to, on top of
    the analytic potential curve (``exes``, ``whys``). Draws on the current
    pyplot axes.

    Reads the notebook globals msm_mdl, ccs, nrgs, exes, whys, halfwidth and
    color_list. Assumes ccs and nrgs are indexed by MSM state, like
    ``microstate_mapping_`` -- TODO confirm if the MSM renumbers bins.
    """
    mvca_mdl = MVCA.from_msm(msm_mdl, n_macrostates=n_macrostates)
    for state, macrostate in enumerate(mvca_mdl.microstate_mapping_):
        _, stemlines, _ = plt.stem([ccs[state]], [nrgs[state]],
                                   color_list[macrostate],
                                   markerfmt=' ', bottom=-1)
        # setp works whether Matplotlib returns the stems as a list of Line2D
        # or as one LineCollection (iterating the latter fails).
        plt.setp(stemlines, linewidth=5)
    plt.plot(exes, whys, color='black', linewidth=1.5)
    plt.ylim([0, 4])
    plt.xlim([-halfwidth, halfwidth])

In [ ]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6))

# One panel per model size, from 2 to 7 macrostates, filled row by row.
for ax, n_macrostates in zip(axes.flat, range(2, 8)):
    plt.sca(ax)  # plot_macrostates draws on the current pyplot axes
    plot_macrostates(n_macrostates=n_macrostates)
    ax.set_title('%i macrostates' % n_macrostates)

plt.tight_layout()