In [ ]:
from msmbuilder.example_datasets import QuadWell
from msmbuilder.msm import MarkovStateModel
from msmbuilder.lumping import MVCA
import numpy as np
import scipy.cluster.hierarchy
import matplotlib.pyplot as plt
% matplotlib inline
In [ ]:
# Load the QuadWell example dataset and keep only its trajectories.
# NOTE(review): random_state=998 presumably pins the generated data so the
# notebook is reproducible -- confirm against the QuadWell documentation.
q = QuadWell(random_state=998).get()
ds = q['trajectories']
In [ ]:
def regular_spatial_clustering(ds, n_bins=20, halfwidth=np.pi):
    """Discretize 1-D trajectories onto a regular grid of equal-width bins.

    The interval ``[-halfwidth, halfwidth]`` is split into ``n_bins`` bins of
    equal width and every frame is replaced by the index of its bin.

    Parameters
    ----------
    ds : iterable of array-like
        Trajectories with a single coordinate per frame, shaped either
        ``(n_frames,)`` or ``(n_frames, 1)``.
    n_bins : int
        Number of bins covering the interval.
    halfwidth : float
        Half of the total width of the binned interval.

    Returns
    -------
    list of numpy.ndarray
        One 1-D integer array of bin indices per input trajectory.

    Notes
    -----
    The last bin is closed on the right: a frame lying exactly on
    ``+halfwidth`` is assigned to bin ``n_bins - 1`` instead of the
    non-existent bin ``n_bins``. Frames outside the interval are not clipped
    and yield indices outside ``0 .. n_bins - 1``.
    """
    width = 2 * halfwidth
    new_ds = []
    for t in ds:
        frames = np.asarray(t, dtype=float)
        # Flatten (n_frames, 1) to (n_frames,); raises if a frame holds more
        # than one coordinate.
        frames = frames.reshape(len(frames))
        bins = np.floor(n_bins * (frames + halfwidth) / width).astype(int)
        # Compare against the edge itself rather than the computed index so
        # the fix does not depend on floating-point rounding.
        bins[frames == halfwidth] = n_bins - 1
        new_ds.append(bins)
    return new_ds
In [ ]:
# The grid half-width is the largest absolute coordinate seen in any
# trajectory, so every frame lies inside [-halfwidth, halfwidth].
halfwidth = max(np.abs(traj).max() for traj in ds)
# Discretize the trajectories on the regular grid and fit a Markov state
# model to the resulting bin-index sequences.
assignments = regular_spatial_clustering(ds, halfwidth=halfwidth)
msm_mdl = MarkovStateModel()
msm_mdl.fit(assignments)
In [ ]:
def get_centers(n_bins=20, halfwidth=np.pi):
    """Return the midpoints of ``n_bins`` equal-width bins on [-halfwidth, halfwidth].

    Parameters
    ----------
    n_bins : int
        Number of bins the interval is divided into.
    halfwidth : float
        Half of the total width of the interval.

    Returns
    -------
    list of float
        Bin centers in ascending order.
    """
    bin_width = (2*halfwidth)/n_bins
    # Right edge of bin k, minus half a bin, shifted so the grid starts at
    # -halfwidth.
    return [(k+1)*bin_width - bin_width/2. - halfwidth
            for k in range(n_bins)]
In [ ]:
# Bin centers and a free-energy estimate per microstate from the MSM
# equilibrium populations.
# NOTE(review): the 0.6 prefactor is presumably kT in the units of the
# potential -- confirm.
ccs = get_centers(halfwidth=halfwidth)
nrgs = [-0.6*np.log(p) for p in msm_mdl.populations_]
m, s, b = plt.stem(ccs, nrgs, 'deepskyblue', bottom=-1)
# Depending on the Matplotlib version the stem lines are either a list of
# Line2D or a single LineCollection, which cannot be iterated; plt.setp
# accepts both.
plt.setp(s, linewidth=8)


def potential(x):
    """Evaluate the reference potential curve at ``x`` (scalar or array).

    NOTE(review): presumably the analytic potential that generated the
    QuadWell data -- confirm against the dataset definition.
    """
    return 4 * (x ** 8 + 0.8 * np.exp(-80 * x ** 2) + 0.2 * np.exp(
        -80 * (x - 0.5) ** 2) +
        0.5 * np.exp(-40 * (x + 0.5) ** 2))


# Overlay the reference potential on the estimated free energies.
exes = np.linspace(-np.pi, np.pi, 1000)
whys = potential(exes)
plt.plot(exes, whys, linewidth=2, color='k')
plt.xlim([-halfwidth, halfwidth])
plt.ylim([0, 4])
In [ ]:
# Build an MVCA model from the fitted MSM. get_linkage=True is requested
# because the following cells read mvca.linkage and mvca.elbow_data.
# NOTE(review): n_macrostates=None presumably means no particular lumping is
# selected at this point -- confirm against the MVCA documentation.
mvca = MVCA.from_msm(msm_mdl, n_macrostates=None, get_linkage=True)
Use mvca.linkage to get a scipy linkage object
In [ ]:
# Draw the agglomeration tree stored in mvca.linkage. Leaf labels are
# suppressed (no_labels=True); color_threshold=1.1 sets the distance below
# which SciPy colors a subtree as one cluster.
scipy.cluster.hierarchy.dendrogram(mvca.linkage,
color_threshold=1.1,
no_labels=True)
plt.show()
Use `mvca.elbow_data` (plotted below) to see how the objective function changes with agglomeration
In [ ]:
# Elbow plot: one point per candidate number of macrostates, drawn with the
# number itself as the marker glyph.
for n_states in range(1, 20):
    plt.scatter([n_states], mvca.elbow_data[n_states - 1], color='k',
                marker=r'$%s$' % n_states,
                # Two-digit glyphs get twice the marker area so numbers are
                # approximately the same size.
                s=60 * (n_states // 10 + 1))
plt.xlabel('Number of macrostates')
# The markers already carry the x values, so tick labels are redundant.
plt.xticks([])
plt.show()
In [ ]:
# Colors used by plot_macrostates, indexed by macrostate id. There are eight
# entries, so at most eight macrostates can be drawn.
color_list = ['deepskyblue', 'hotpink', 'turquoise', 'indigo', 'gold',
'olivedrab', 'orangered', 'whitesmoke']
In [ ]:
def plot_macrostates(n_macrostates=4):
    """Plot microstate free energies colored by their MVCA macrostate.

    Lumps the notebook's fitted MSM (``msm_mdl``) into ``n_macrostates``
    macrostates with MVCA, then draws one stem per microstate on the current
    axes at its bin center (``ccs``) with height equal to its free energy
    (``nrgs``), colored by macrostate from ``color_list``. The reference
    curve (``exes``, ``whys``) is drawn on top.

    Parameters
    ----------
    n_macrostates : int
        Number of macrostates to lump into; must not exceed
        ``len(color_list)``.
    """
    mvca_mdl = MVCA.from_msm(msm_mdl, n_macrostates=n_macrostates)
    for micro, macro in enumerate(mvca_mdl.microstate_mapping_):
        _, stemlines, _ = plt.stem([ccs[micro]], [nrgs[micro]],
                                   color_list[macro],
                                   markerfmt=' ', bottom=-1)
        # Depending on the Matplotlib version the stem lines are either a
        # list of Line2D or a single LineCollection, which cannot be
        # iterated; plt.setp accepts both.
        plt.setp(stemlines, linewidth=5)
    plt.plot(exes, whys, color='black', linewidth=1.5)
    plt.ylim([0, 4])
    plt.xlim([-halfwidth, halfwidth])
In [ ]:
# 2 x 3 grid of panels showing the lumping for 2 through 7 macrostates.
plt.subplots(nrows=2, ncols=3, figsize=(12,6))
for panel, n_states in enumerate(range(2, 8), start=1):
    plt.subplot(2, 3, panel)
    plot_macrostates(n_macrostates=n_states)
    plt.title('%i macrostates' % n_states)
plt.tight_layout()