In [1]:
from msmbuilder.example_datasets import AlanineDipeptide
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
# Load the alanine dipeptide example trajectories.
trajs = AlanineDipeptide().get().trajectories
from msmbuilder.featurizer import DihedralFeaturizer,AtomPairsFeaturizer

# Enumerate every unordered atom pair (i, j) with j < i.
n_atoms = trajs[0].n_atoms
pairs = [(i, j) for i in range(n_atoms) for j in range(i)]

apf = AtomPairsFeaturizer(pairs)
dih = DihedralFeaturizer()
X_apf = apf.fit_transform(trajs)
# Fixed: this previously called apf.fit_transform, so the dihedral featurizer
# was never used and X contained the atom-pair features twice.
X_dih = dih.fit_transform(trajs)

# Concatenate both feature sets column-wise, one array per trajectory.
X = [np.hstack([X_apf[i], X_dih[i]]) for i in range(len(trajs))]
In [3]:
# NOTE(review): pyemma is imported here rather than in the first (imports) cell.
import pyemma
# Apply PyEMMA's tica to the combined feature arrays and collect its output.
tica = pyemma.coordinates.tica(X)
X_tica = tica.get_output()
# Mini-batch k-means with k=100 on the tica output; each element of
# get_output() is flattened to give one 1-D discrete trajectory per input.
kmeans = pyemma.coordinates.cluster_mini_batch_kmeans(X_tica,k=100,max_iter=1000)
dtrajs = [dtraj.flatten() for dtraj in kmeans.get_output()]
In [4]:
def milestone(dtrajs,milestoning_set):
    """Relabel every frame by the most recently visited milestone state.

    Parameters
    ----------
    dtrajs : iterable of integer sequences
        Discrete trajectories (one sequence of state indices each).
    milestoning_set : iterable of int
        State indices that act as milestones. Duplicates are ignored.

    Returns
    -------
    list of numpy.ndarray
        One integer array per input trajectory, same length as the input.
        Each entry is ``last_milestone + 1``; frames that occur before any
        milestone has been visited are therefore labelled ``0``.
    """
    # Normalise once to a set so the per-frame membership test is a hash
    # lookup instead of a scan over the caller's array/list.
    milestone_states = set(milestoning_set)

    milestoned_dtrajs = []
    for dtraj in dtrajs:
        milestoned_traj = np.zeros(len(dtraj), dtype=int)
        last_milestone_visited = -1  # sentinel: no milestone seen yet
        for t in range(len(dtraj)):
            state = dtraj[t]
            if state in milestone_states:
                last_milestone_visited = state
            milestoned_traj[t] = last_milestone_visited
        # Shift by one so the "no milestone yet" sentinel becomes state 0.
        milestoned_dtrajs.append(milestoned_traj + 1)
    return milestoned_dtrajs
In [5]:
# choose random milestones: 10 distinct state indices out of the 100 clusters
np.random.seed(0)
# Sample without replacement. np.random.randint draws with replacement, so it
# can return the same state twice and yield fewer than 10 distinct milestones.
milestones = np.random.choice(100, size=10, replace=False)
dtrajs_m = milestone(dtrajs,milestones)
In [6]:
# Display the milestoned discrete trajectories returned by milestone().
dtrajs_m
Out[6]:
In [7]:
# Implied timescales over lag times 1..100 for the original discrete
# trajectories (its_0) and the milestoned ones (its_1), plotted side by side
# in a 1x2 subplot grid.
lags = range(1,101)
its_0 = pyemma.msm.its(dtrajs,lags=lags,nits=20)
its_1 = pyemma.msm.its(dtrajs_m,lags=lags,nits=20)
plt.subplot(1,2,1)
pyemma.plots.plot_implied_timescales(its_0)
plt.subplot(1,2,2)
pyemma.plots.plot_implied_timescales(its_1)
Out[7]:
In [8]:
# Same two implied-timescale plots as above, each in its own figure.
pyemma.plots.plot_implied_timescales(its_0)
plt.figure()
pyemma.plots.plot_implied_timescales(its_1)
Out[8]:
In [9]:
# Estimate one MSM from the original trajectories (msm_0) and one from the
# milestoned trajectories (msm_1) at the same lag time.
# NOTE: `lag` is also read as a global by sample_random_milestone_sets and by
# the exhaustive enumeration cell further down.
lag=1
msm_0 = pyemma.msm.estimate_markov_model(dtrajs,lag=lag)
msm_1 = pyemma.msm.estimate_markov_model(dtrajs_m,lag=lag)
In [10]:
# Compare the active_count_fraction attribute of the two MSMs.
msm_0.active_count_fraction,msm_1.active_count_fraction
Out[10]:
In [11]:
# Sum of msm_0's eigenvalues, and that sum divided by the number of rows of
# its transition matrix.
sum(msm_0.eigenvalues()),sum(msm_0.eigenvalues())/len(msm_0.transition_matrix)
Out[11]:
In [12]:
# Same two quantities for the milestoned MSM: eigenvalue sum, and that sum
# divided by the number of rows of its transition matrix.
sum(msm_1.eigenvalues()),sum(msm_1.eigenvalues())/len(msm_1.transition_matrix)
Out[12]:
Question: does that objective function (eqn. 4 in the paper) make sense / correlate with the metastability of the resulting MSM?
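As implemented in `metastability_index` below, the index is
$$\rho_\mathcal{M} = \frac{\max_{i\in\mathcal{M}}\,\max_{j\in\mathcal{M}\setminus\{i\}} \Gamma_{ij}}{\min_{i\notin\mathcal{M}}\,\max_{j\in\mathcal{M}} \Gamma_{ij}}, \qquad \Gamma_{ij} = \frac{N_{ij}}{N_i},$$
where $N_i$ and $N_{ij}$ are formed as follows.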
Given a trajectory $(i,k,i,j,k,j,i,k,j)$, we break it into pieces that each start with $i$ and contain all the indices until $i$ is visited again: $(i,k)$, $(i,j,k,j)$, $(i,k,j)$.
$N_{ij}$ is then the number of pieces that contain $j$ at least once, and $N_i$ is the number of pieces.
In [13]:
def form_gamma(dtrajs, n_states=None):
    """Build the matrix Gamma with Gamma[i, j] = N_ij / N_i.

    Each trajectory is cut, for every state ``i``, into pieces that start at
    an occurrence of ``i`` and run up to (but not including) the next
    occurrence of ``i`` or the end of the trajectory. ``N_ij`` counts the
    pieces that contain ``j`` at least once; ``N_i`` is the number of pieces,
    i.e. the number of frames in state ``i`` over all trajectories.

    Parameters
    ----------
    dtrajs : list of 1-D integer sequences
        Discrete trajectories with non-negative state indices.
    n_states : int, optional
        Size of the state space. Defaults to ``max label + 1`` so that a label
        that never occurs cannot cause an out-of-range index (the previous
        ``len(set(labels))`` did when labels were not contiguous).

    Returns
    -------
    numpy.ndarray, shape (n_states, n_states)
        Gamma, with a zero diagonal. Rows of states that are never visited
        are left at zero instead of becoming NaN through a 0/0 division.

    Raises
    ------
    ValueError
        If ``n_states`` is given but not larger than the largest label.
    """
    dtrajs = [np.asarray(dtraj) for dtraj in dtrajs]
    dtraj_stack = np.hstack(dtrajs)

    largest_label = int(dtraj_stack.max())
    if n_states is None:
        n_states = largest_label + 1
    elif n_states <= largest_label:
        raise ValueError(
            "n_states=%d is too small for state label %d" % (n_states, largest_label))

    gamma = np.zeros((n_states, n_states))
    for dtraj in dtrajs:
        # States absent from this trajectory contribute no pieces.
        for i in np.unique(dtraj):
            i = int(i)
            starts = np.flatnonzero(dtraj == i)
            bounds = np.append(starts, len(dtraj))
            for start, stop in zip(bounds[:-1], bounds[1:]):
                visited = set(dtraj[start:stop].tolist())
                visited.discard(i)
                for j in visited:
                    gamma[i, j] += 1

    # N_i: total number of frames spent in each state.
    counts = np.bincount(dtraj_stack, minlength=n_states)
    visited_states = counts > 0
    gamma[visited_states] /= counts[visited_states][:, None]
    return gamma
In [14]:
# Gamma for the current discrete trajectories. It is read as a global by
# sample_random_milestone_sets.
gamma = form_gamma(dtrajs)
In [15]:
# Display the Gamma matrix.
gamma
Out[15]:
In [16]:
# Heat map of Gamma with a colour bar.
plt.imshow(gamma,interpolation='none',cmap='Blues')
plt.colorbar()
plt.title(r'$\Gamma$')
Out[16]:
In [17]:
def _running_max(values):
    """Largest value in ``values`` that exceeds 0.0, or 0.0 if none does."""
    best = 0.0
    for value in values:
        if value > best:
            best = value
    return best


def metastability_index(gamma,M):
    """Compute the metastability index of milestone set ``M`` from ``gamma``.

    The value returned is::

        max_{i in M} max_{j in M, j != i} gamma[i, j]
        ---------------------------------------------
           min_{i not in M} max_{j in M} gamma[i, j]

    Parameters
    ----------
    gamma : 2-D array indexable as ``gamma[i, j]``
        Square matrix over all states.
    M : iterable of int
        Milestone state indices. Duplicates are ignored.

    Returns
    -------
    float
        The ratio above. If the denominator is zero the result is ``inf``,
        or ``nan`` when the numerator is zero as well. Previously this case
        either raised ZeroDivisionError or returned inf with a warning,
        depending on the operand types.

    Raises
    ------
    ValueError
        If every state is a milestone, so the denominator's minimum is taken
        over an empty set.
    """
    milestone_states = set(M)
    other_states = set(range(len(gamma))) - milestone_states
    if not other_states:
        raise ValueError(
            "metastability_index needs at least one non-milestone state")

    # Numerator: largest gamma between two distinct milestones.
    numerator = _running_max(
        gamma[i, j]
        for i in milestone_states
        for j in milestone_states - {i}
    )

    # Denominator: for each non-milestone state take its largest gamma to any
    # milestone, then keep the smallest of those.
    denominator = min(
        _running_max(gamma[i, j] for j in milestone_states)
        for i in other_states
    )

    if denominator == 0:
        return float('nan') if numerator == 0 else float('inf')
    return numerator / denominator
In [18]:
# Metastability index of the randomly chosen milestone set.
metastability_index(gamma,milestones)
Out[18]:
In [19]:
np.random.seed(0)
def sample_random_milestone_sets(n_milestones,n_states=100,n_samples=100):
    """Draw random milestone sets and evaluate each one.

    For every sample the state indices ``0..n_states-1`` are shuffled in place
    and the first ``n_milestones`` of them are used as the milestone set.

    Reads the notebook globals ``dtrajs``, ``lag`` and ``gamma``.

    Returns
    -------
    (list, list)
        The MSMs estimated from the milestoned trajectories, and the matching
        metastability indices, both in sampling order.
    """
    state_order = np.arange(n_states)
    sampled_msms, sampled_indices = [], []
    for _ in range(n_samples):
        # In-place shuffle: each draw continues from the previous permutation.
        np.random.shuffle(state_order)
        chosen = state_order[:n_milestones]
        model = pyemma.msm.estimate_markov_model(milestone(dtrajs, chosen), lag=lag)
        sampled_msms.append(model)
        sampled_indices.append(metastability_index(gamma, chosen))
    return sampled_msms, sampled_indices
In [20]:
# 500 random milestone sets of size 5: scatter the metastability index (x)
# against the trace of each milestoned MSM's transition matrix (y).
msms,metastability_indices = sample_random_milestone_sets(5,n_states=100,n_samples=500)
msm_metastability = [np.trace(msm.transition_matrix) for msm in msms]
plt.scatter(metastability_indices,msm_metastability,linewidths=0)
plt.ylabel('Trace of Milestoned MSM')
plt.title(r'$\rho_\mathcal{M}$ vs. metastability')
Out[20]:
In [21]:
# Same experiment with milestone sets of size 50.
# NOTE(review): identical to the previous cell apart from n_milestones.
msms,metastability_indices = sample_random_milestone_sets(50,n_states=100,n_samples=500)
msm_metastability = [np.trace(msm.transition_matrix) for msm in msms]
plt.scatter(metastability_indices,msm_metastability,linewidths=0)
plt.ylabel('Trace of Milestoned MSM')
plt.title(r'$\rho_\mathcal{M}$ vs. metastability')
Out[21]:
In [22]:
Ms = [5,10,15] # various n_milestones
n_samples = 500 # number of samples for each n_milestones
# One (msms, metastability_indices) tuple per milestone-set size, in the
# order of Ms.
results = [
    sample_random_milestone_sets(M,n_states=100,n_samples=n_samples)
    for M in Ms
]
In [23]:
# coloring by the number of milestones: every sample gets the |M| of the
# batch it came from, in the order the batches were generated
colors = np.repeat(np.asarray(Ms, dtype=float), n_samples)
In [24]:
# Concatenate the metastability indices (second element of each result tuple)
# of all batches into one array.
metastability_indices = np.hstack([r[1] for r in results])
In [25]:
# Flatten the per-batch MSM lists (first element of each result tuple) into a
# single list, then take the trace of every transition matrix.
msms = [model for batch in results for model in batch[0]]
msm_metastability = [np.trace(model.transition_matrix) for model in msms]
In [26]:
# Mean and standard deviation of the transition-matrix traces over all samples.
np.mean(msm_metastability),np.std(msm_metastability)
Out[26]:
In [27]:
# Metastability index vs. transition-matrix trace, coloured by the number of
# milestones in each sampled set.
plt.scatter(metastability_indices,msm_metastability,c=colors,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel('Trace of Milestoned MSM')
plt.title(r'$\rho_\mathcal{M}$ vs. metastability')
plt.colorbar()
Out[27]:
In [28]:
# Same scatter with the trace divided by the number of milestones; `colors`
# holds that number for every sample, so it serves as the divisor too.
plt.scatter(metastability_indices,msm_metastability/colors,c=colors,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel(r'Trace of Milestoned MSM / $|\mathcal{M}|$')
plt.title(r'$\rho_\mathcal{M}$ vs. "fractional metastability"')
plt.colorbar()
Out[28]:
In [34]:
Ms = range(2,20) # various n_milestones
n_samples = 100 # number of samples for each n_milestones

# One (msms, metastability_indices) tuple per milestone-set size.
results = [
    sample_random_milestone_sets(M,n_states=100,n_samples=n_samples)
    for M in Ms
]

# coloring by the number of milestones
colors = np.repeat(np.asarray(Ms, dtype=float), n_samples)

# Flatten the batches into per-sample sequences that line up with `colors`.
metastability_indices = np.hstack([batch[1] for batch in results])
msms = [model for batch in results for model in batch[0]]
msm_metastability = [np.trace(model.transition_matrix) for model in msms]

# Figure 1: raw trace against the metastability index.
plt.scatter(metastability_indices,msm_metastability,c=colors,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel('Trace of Milestoned MSM')
plt.title(r'$\rho_\mathcal{M}$ vs. metastability')
plt.colorbar()

# Figure 2: trace divided by the number of milestones.
plt.figure()
plt.scatter(metastability_indices,msm_metastability/colors,c=colors,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel(r'Trace of Milestoned MSM / $|\mathcal{M}|$')
plt.title(r'$\rho_\mathcal{M}$ vs. "fractional metastability"')
plt.colorbar()
Out[34]:
In [35]:
# The colors array contains the number of milestones used for each sample, so
# this is the milestone count of the sample with the smallest metastability index.
colors[np.argmin(metastability_indices)]
Out[35]:
In [36]:
# Smallest metastability index among all sampled milestone sets.
np.min(metastability_indices)
Out[36]:
In [ ]:
# Let's do this deterministically. If there are 10 candidate milestones, then there are 2**10 - 2 = 1022 possible milestone sets (every subset except the empty set and the full set).
In [37]:
# Re-cluster the tica output with k=10 so that all milestone subsets can be
# enumerated.
# NOTE(review): this overwrites the k=100 `kmeans` and `dtrajs` from the
# earlier clustering cell; `gamma` must be recomputed before it is used again.
kmeans = pyemma.coordinates.cluster_mini_batch_kmeans(X_tica,k=10,max_iter=1000)
dtrajs = [dtraj.flatten() for dtraj in kmeans.get_output()]
In [50]:
# Recompute Gamma for the current `dtrajs`; this replaces the earlier `gamma`.
gamma = form_gamma(dtrajs)
In [59]:
# Heat map of the recomputed Gamma with a colour bar.
plt.imshow(gamma,interpolation='none',cmap='Blues')
plt.colorbar()
Out[59]:
In [83]:
import itertools

# Number of cluster states that may be chosen as milestones.
n_candidate_states = 10

metastability_indices = []
msms = []
n_milestones = []
# Milestone sets for which estimation or the index computation failed.
skipped_milestone_sets = []

# Every bitmask over the candidate states selects one milestone subset.
for bitmask in itertools.product([0,1],repeat=n_candidate_states):
    n_selected = int(np.sum(bitmask))
    # Skip the empty set and the full set.
    if n_selected == 0 or n_selected == len(bitmask):
        continue
    milestones = np.arange(n_candidate_states)[np.array(bitmask,dtype=bool)]
    try: # sometimes msm estimation fails here
        dtrajs_m = milestone(dtrajs,milestones)
        msm_1 = pyemma.msm.estimate_markov_model(dtrajs_m,lag=lag)
        index = metastability_index(gamma,milestones)
    except Exception:
        # Catch Exception rather than everything so KeyboardInterrupt still
        # stops the loop, and keep a record of what was skipped.
        skipped_milestone_sets.append(milestones)
        continue
    # Append only after every step succeeded. Previously the MSM was appended
    # before the index was computed, so a failure in between left the three
    # lists with different lengths.
    msms.append(msm_1)
    metastability_indices.append(index)
    n_milestones.append(n_selected)

msm_metastability = [np.trace(msm.transition_matrix) for msm in msms]
In [84]:
# Sanity check: the three per-sample sequences should have the same length.
len(msm_metastability),len(metastability_indices),len(n_milestones)
Out[84]:
In [85]:
# Recompute the transition-matrix traces from `msms` and show how many there are.
# NOTE(review): same expression as the last line of the enumeration cell.
msm_metastability = [np.trace(msm.transition_matrix) for msm in msms]
len(msm_metastability)
Out[85]:
In [86]:
# Scatter plots for the enumerated milestone sets, coloured by the number of
# milestones: first the raw trace, then (in a new figure) the trace divided
# by the number of milestones.
plt.scatter(metastability_indices,msm_metastability,c=n_milestones,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel('Trace of Milestoned MSM')
plt.title(r'$\rho_\mathcal{M}$ vs. metastability')
plt.colorbar()
plt.figure()
plt.scatter(metastability_indices,msm_metastability/np.array(n_milestones),c=n_milestones,linewidths=0,cmap='Spectral_r')
plt.xlabel(r'Proposed "metastability index" $\rho_\mathcal{M}$')
plt.ylabel(r'Trace of Milestoned MSM / $|\mathcal{M}|$')
plt.title(r'$\rho_\mathcal{M}$ vs. "fractional metastability"')
plt.colorbar()
Out[86]:
In [87]:
# Number of milestones in the enumerated set with the smallest metastability index.
n_milestones[np.argmin(metastability_indices)]
Out[87]: