MSM of EF-hand

Build an MSM on the EF-hand domain of cTnC and the catalytic Ca2+ atom.


In [39]:
from msmbuilder.dataset import dataset
import mdtraj as md
import numpy as np
from glob import glob
from mdtraj.utils import timing
import itertools
from msmbuilder.featurizer import AtomPairsFeaturizer
from msmbuilder.decomposition import tICA
from msmbuilder.cluster import MiniBatchKMeans
from msmbuilder.cluster import MiniBatchKMedoids
%matplotlib inline
from matplotlib import pyplot as plt

In [2]:
traj_list = sorted(glob("/Users/je714/wt_data/*/05*nc"))

In [3]:
top = md.load_prmtop("/Users/je714/wt_data/run8/WT-ff14SB_clean.prmtop")

In [4]:
ef_hand = top.select("resid 64 to 74")
cat_cal = top.select("resid 419")

In [5]:
atom_sel = np.append(ef_hand, cat_cal)

In [6]:
atom_sel


Out[6]:
array([ 996,  997,  998,  999, 1000, 1001, 1002, 1003, 1004, 1005, 1006,
       1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017,
       1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028,
       1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039,
       1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050,
       1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061,
       1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072,
       1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083,
       1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094,
       1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105,
       1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116,
       1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127,
       1128, 1129, 1130, 1131, 1132, 1133, 6809])

In [7]:
xyz = dataset("/Users/je714/wt_data/*/05*nc", fmt='mdtraj', 
               topology="/Users/je714/wt_data/run8/WT-ff14SB_clean.prmtop",
               atom_indices=atom_sel)

We featurize the trajectories into the pairwise distance between all the atoms in the EF hand and the catalytic Ca2+. Since we only loaded the relevant atoms, msmbuilder renumbers the indices from 0.


In [8]:
pairs = np.asarray(list(itertools.product(list(range(139)), [138])))

In [9]:
print(pairs.shape)


(139, 2)

In [10]:
pairs_noIJ = pairs[pairs[:,0] != pairs[:,1]] # Drop out the elements that are i=j (whose distance is 0)

In [11]:
pairs_noIJ


Out[11]:
array([[  0, 138],
       [  1, 138],
       [  2, 138],
       [  3, 138],
       [  4, 138],
       [  5, 138],
       [  6, 138],
       [  7, 138],
       [  8, 138],
       [  9, 138],
       [ 10, 138],
       [ 11, 138],
       [ 12, 138],
       [ 13, 138],
       [ 14, 138],
       [ 15, 138],
       [ 16, 138],
       [ 17, 138],
       [ 18, 138],
       [ 19, 138],
       [ 20, 138],
       [ 21, 138],
       [ 22, 138],
       [ 23, 138],
       [ 24, 138],
       [ 25, 138],
       [ 26, 138],
       [ 27, 138],
       [ 28, 138],
       [ 29, 138],
       [ 30, 138],
       [ 31, 138],
       [ 32, 138],
       [ 33, 138],
       [ 34, 138],
       [ 35, 138],
       [ 36, 138],
       [ 37, 138],
       [ 38, 138],
       [ 39, 138],
       [ 40, 138],
       [ 41, 138],
       [ 42, 138],
       [ 43, 138],
       [ 44, 138],
       [ 45, 138],
       [ 46, 138],
       [ 47, 138],
       [ 48, 138],
       [ 49, 138],
       [ 50, 138],
       [ 51, 138],
       [ 52, 138],
       [ 53, 138],
       [ 54, 138],
       [ 55, 138],
       [ 56, 138],
       [ 57, 138],
       [ 58, 138],
       [ 59, 138],
       [ 60, 138],
       [ 61, 138],
       [ 62, 138],
       [ 63, 138],
       [ 64, 138],
       [ 65, 138],
       [ 66, 138],
       [ 67, 138],
       [ 68, 138],
       [ 69, 138],
       [ 70, 138],
       [ 71, 138],
       [ 72, 138],
       [ 73, 138],
       [ 74, 138],
       [ 75, 138],
       [ 76, 138],
       [ 77, 138],
       [ 78, 138],
       [ 79, 138],
       [ 80, 138],
       [ 81, 138],
       [ 82, 138],
       [ 83, 138],
       [ 84, 138],
       [ 85, 138],
       [ 86, 138],
       [ 87, 138],
       [ 88, 138],
       [ 89, 138],
       [ 90, 138],
       [ 91, 138],
       [ 92, 138],
       [ 93, 138],
       [ 94, 138],
       [ 95, 138],
       [ 96, 138],
       [ 97, 138],
       [ 98, 138],
       [ 99, 138],
       [100, 138],
       [101, 138],
       [102, 138],
       [103, 138],
       [104, 138],
       [105, 138],
       [106, 138],
       [107, 138],
       [108, 138],
       [109, 138],
       [110, 138],
       [111, 138],
       [112, 138],
       [113, 138],
       [114, 138],
       [115, 138],
       [116, 138],
       [117, 138],
       [118, 138],
       [119, 138],
       [120, 138],
       [121, 138],
       [122, 138],
       [123, 138],
       [124, 138],
       [125, 138],
       [126, 138],
       [127, 138],
       [128, 138],
       [129, 138],
       [130, 138],
       [131, 138],
       [132, 138],
       [133, 138],
       [134, 138],
       [135, 138],
       [136, 138],
       [137, 138]])

In [12]:
traj1 = xyz[0]

In [13]:
dist_featurizer = AtomPairsFeaturizer(pair_indices=pairs)

In [14]:
with timing("Featurizing into distances..."):
    dists = dist_featurizer.fit_transform(xyz)


Featurizing into distances...: 5008.790 seconds

In [15]:
tica_model = tICA(n_components=10, lag_time=1)
with timing("tica..."):
    tica_trajs_dists = tica_model.fit_transform(dists)


tica...: 2.199 seconds

In [16]:
tica_model = tICA

In [17]:
tica_trajs_dists[0].shape


Out[17]:
(2500, 10)

In [61]:
def tica_plotter(tica_trajs):
    
    txx = np.concatenate(tica_trajs)

    plt.figure(figsize=(14, 4))
    plt.subplot(1, 2, 1)
    plt.hexbin(txx[:, 0], txx[:, 1], bins='log', mincnt=1)
    plt.xlabel('tIC 1')
    plt.ylabel('tIC 2')
    cb = plt.colorbar()
    cb.set_label('log10(N)')

    plt.subplot(1, 2, 2)
    
    time = np.linspace(start=0.0, stop=txx.shape[0]*0.02, num = txx.shape[0]) / 1000
    plt.plot(time, txx[:, 0])
    plt.xlabel("Aggregated time ($\mu$s)")
    plt.ylabel("tIC 1")
    
    plt.tight_layout()

In [62]:
tica_plotter(tica_trajs_dists)



In [19]:
clusterer = MiniBatchKMeans(n_clusters=400)
with timing("Clustering tica trajs from dist featurization..."):
    clustered_trajs = clusterer.fit_transform(tica_trajs_dists)


/Users/je714/anaconda3/lib/python3.5/site-packages/sklearn/cluster/k_means_.py:1300: RuntimeWarning: init_size=300 should be larger than k=400. Setting it to 3*k
  init_size=init_size)
Clustering tica trajs from dist featurization...: 7.206 seconds

In [57]:
def cluster_plotter(trajectory, clusterer_object):
    txx = np.concatenate(trajectory)
    plt.hexbin(txx[:,0], txx[:,1], bins='log', mincnt=1)
    plt.scatter(clusterer_object.cluster_centers_[:,0],
                clusterer_object.cluster_centers_[:,1], 
                s=100, c='w')

In [59]:
cluster_plotter(tica_trajs_dists, clusterer)



In [21]:
from msmbuilder.msm import MarkovStateModel
from msmbuilder.utils import dump
msm = MarkovStateModel(lag_time=50)
msm.fit(clustered_trajs)
print("The MSM has %s states.\n" % msm.n_states_)
print(msm.left_eigenvectors_.shape)


MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
The MSM has 400 states.

(400, 400)

In [22]:
plt.hexbin(txx[:,0], txx[:,1], bins='log', mincnt=1, cmap="Greys")
plt.scatter(clusterer.cluster_centers_[:,0],
            clusterer.cluster_centers_[:,1],
            s=1e4 * msm.populations_, # size by population
            c=msm.left_eigenvectors_[:,1], # color by eigenvector
            cmap="RdBu") 
plt.colorbar(label='First dynamical eigenvector')
plt.xlabel('tIC 1')
plt.ylabel('tIC 2')
plt.tight_layout()



In [23]:
from msmbuilder.lumping import PCCAPlus
pcca = PCCAPlus.from_msm(msm, n_macrostates=5)
macro_trajs = pcca.transform(clustered_trajs)


Optimization terminated successfully.
         Current function value: -4.833465
         Iterations: 11
         Function evaluations: 147

In [24]:
plt.hexbin(txx[:,0], txx[:,1], bins='log', mincnt=1, cmap="Greys")
plt.scatter(clusterer.cluster_centers_[:,0],
            clusterer.cluster_centers_[:,1],
            s=100,
            c=pcca.microstate_mapping_,
      )
plt.xlabel('tIC 1')
plt.ylabel('tIC 2')


Out[24]:
<matplotlib.text.Text at 0x118d1b5c0>

Cross validation


In [25]:
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import KFold
from sklearn.grid_search import RandomizedSearchCV
from scipy.stats.distributions import randint

In [26]:
model = Pipeline([
    ('featurizer', AtomPairsFeaturizer(pair_indices=pairs)),
    ('tica', tICA(n_components=4)),
    ('clusterer', MiniBatchKMeans()),
    ('msm', MarkovStateModel(n_timescales=3, ergodic_cutoff='on', reversible_type='mle', verbose=False))
])
search = RandomizedSearchCV(model, n_iter=50, cv=3, refit=False, param_distributions={
        'tica__lag_time':randint(1,100),    
        'clusterer__n_clusters':randint(1,1000),
        'msm__lag_time':randint(1,100)
})

In [27]:
def plot_scores(score_matrix):
    from matplotlib import pyplot as pp
    scores = np.array([[np.mean(e.cv_validation_scores),
                        np.std(e.cv_validation_scores),
                        e.parameters['tica__lag_time'],
                        e.parameters['clusterer__n_clusters'],
                        e.parameters['msm__lag_time']]
                        for e in score_matrix.grid_scores_])

    mean_score_value = scores[:,0]
    std_score_value = scores[:,1]
    lags_tica = scores[:,2]
    cluster_number = scores[:,3]
    lags_msm = scores[:,4]

    pp.figure(figsize=(14,4))
    pp.grid(False)
    interval = 2*std_score_value
    # subplot1
    pp.subplot(1,3,1, axisbg='white')
    pp.errorbar(x=lags_tica, y=mean_score_value, yerr=interval, color = 'b', fmt='o')
    pp.plot(pp.xlim(), [search.best_score_]*2, 'k-.', label = 'best')
    pp.xlabel("tICA lag time (Frame)")
    pp.ylabel("Score")
    # subplot2
    pp.subplot(1,3,2, axisbg='white')
    pp.errorbar(x=cluster_number, y=mean_score_value, yerr=interval,  color = 'r', fmt='o')
    pp.plot(pp.xlim(), [search.best_score_]*2, 'k-.', label = 'best')
    pp.xlabel("Number of clusters")
    pp.ylabel("Score")
    # subplot3
    pp.subplot(1,3,3, axisbg='white')
    pp.errorbar(x=lags_msm, y=mean_score_value, yerr=interval, color = 'g', fmt='o')
    pp.plot(pp.xlim(), [search.best_score_]*2, 'k-.', label = 'best')
    pp.xlabel("MSM lag time (Frame)")
    pp.ylabel("Score")

    pp.tight_layout()

In [28]:
# with timing("Cross validation..."):
#     xyz_scores = search.fit(xyz)
# for key, value in xyz_scores.best_params_.items():
#     print("%s\t%s" % (key, value))
# print("Score = %.2f" % xyz_scores.best_score_)

In [29]:
# plot_scores(xyz_scores)

In [30]:
model = Pipeline([
    ('tica', tICA()),
    ('clusterer', MiniBatchKMedoids()),
    ('msm', MarkovStateModel(ergodic_cutoff='on', reversible_type='mle', verbose=False))
])
search = RandomizedSearchCV(model, n_iter=100, cv=5, refit=False, param_distributions={
        'tica__lag_time':randint(1,200),
        'tica__n_components':randint(1,15),
        'clusterer__n_clusters':randint(1,1000),
        'msm__n_timescales':randint(1,15),
        'msm__lag_time':randint(1,100)
})

In [31]:
print(search)


RandomizedSearchCV(cv=5, error_score='raise',
          estimator=Pipeline(steps=[('tica', tICA(kinetic_mapping=False, lag_time=1, n_components=None, shrinkage=None,
   weighted_transform=False)), ('clusterer', MiniBatchKMedoids(batch_size=100, max_iter=5, max_no_improvement=10,
         metric='euclidean', n_clusters=8, random_state=None)), ('msm', MarkovStateModel(ergodic_cutoff='on', lag_time=1, n_timescales=None,
         prior_counts=0, reversible_type='mle', sliding_window=True,
         verbose=False))]),
          fit_params={}, iid=True, n_iter=100, n_jobs=1,
          param_distributions={'tica__lag_time': <scipy.stats._distn_infrastructure.rv_frozen object at 0x118d0c588>, 'clusterer__n_clusters': <scipy.stats._distn_infrastructure.rv_frozen object at 0x118d11eb8>, 'msm__lag_time': <scipy.stats._distn_infrastructure.rv_frozen object at 0x118d203c8>, 'msm__n_timescales': <scipy.stats._distn_infrastructure.rv_frozen object at 0x118d110b8>, 'tica__n_components': <scipy.stats._distn_infrastructure.rv_frozen object at 0x118d11fd0>},
          pre_dispatch='2*n_jobs', random_state=None, refit=False,
          scoring=None, verbose=0)

In [32]:
with timing("Cross validation on trajectories already fit to distance feature..."):
    feat_scores = search.fit(dists)


Cross validation on trajectories already fit to distance feature...: 72756.292 seconds

In [33]:
feat_scores.best_params_


Out[33]:
{'clusterer__n_clusters': 346,
 'msm__lag_time': 1,
 'msm__n_timescales': 14,
 'tica__lag_time': 118,
 'tica__n_components': 13}

In [34]:
plot_scores(feat_scores)


feat_scores.gridscores

Use best scores from cross-validation to build an MSM


In [110]:
opt_tica = tICA(lag_time=118, n_components=13)
opt_cluster = MiniBatchKMedoids(n_clusters=346)
opt_msm = MarkovStateModel(lag_time=20, n_timescales=14)

In [111]:
opt_tica_trajs = opt_tica.fit_transform(dists)

In [112]:
tica_plotter(opt_tica_trajs)



In [113]:
clustered_opt_tica_trajs = opt_cluster.fit_transform(opt_tica_trajs)

In [114]:
cluster_plotter(opt_tica_trajs, opt_cluster)



In [115]:
opt_msm.fit(clustered_opt_tica_trajs)


MSM contains 1 strongly connected component above weight=0.05. Component 0 selected, with population 100.000000%
Out[115]:
MarkovStateModel(ergodic_cutoff='on', lag_time=20, n_timescales=14,
         prior_counts=0, reversible_type='mle', sliding_window=True,
         verbose=True)

In [104]:
lag_times = [x*2 for x in range(1,55,2)]
print(lag_times)


[2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 98, 102, 106]

In [105]:
opt_msm.timescales_[0:3]


Out[105]:
array([ 1647.30899761,   290.03715984,   262.85624916])

In [106]:
first_three_timescales = []
for lag_time in lag_times:
    msm = MarkovStateModel(lag_time=lag_time, n_timescales=14)
    msm.fit(clustered_opt_tica_trajs)
    first_three_timescales.append(msm.timescales_[0:3])


MSM contains 1 strongly connected component above weight=0.50. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.17. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.10. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.07. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.06. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.05. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.04. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.03. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.03. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.03. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.02. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%
MSM contains 1 strongly connected component above weight=0.01. Component 0 selected, with population 100.000000%

In [107]:
first_three_timescales


Out[107]:
[array([ 2441.28606904,   537.68523741,   471.33590618]),
 array([ 3800.53591513,  1153.0898512 ,  1026.64148586]),
 array([ 4473.78904314,  1851.47683738,  1476.75611017]),
 array([ 4937.39393302,  2278.42365529,  1834.98463499]),
 array([ 5337.052124  ,  2592.75804395,  2112.34200228]),
 array([ 5820.42479581,  3041.20465075,  2284.37583172]),
 array([ 6009.66933233,  3183.34200717,  2518.18733521]),
 array([ 6630.12698798,  3486.98835877,  2696.45378387]),
 array([ 6882.01356162,  3788.08594298,  2786.70487753]),
 array([ 7235.0851213 ,  3893.07429226,  2911.84199607]),
 array([ 7514.71629133,  4019.83496324,  3015.07728047]),
 array([ 8186.32072267,  4242.92969391,  3216.96171154]),
 array([ 8707.51635214,  4353.47006543,  3336.20200915]),
 array([ 9429.73227854,  4358.73315548,  3416.64679871]),
 array([ 10067.75567161,   4427.4567985 ,   3514.37560513]),
 array([ 10671.08899655,   4389.56826556,   3512.77290621]),
 array([ 11648.36064748,   4404.44035455,   3562.73387316]),
 array([ 13566.50474013,   4405.52138613,   3657.00340174]),
 array([ 13557.47427092,   4435.21615202,   3652.00258118]),
 array([ 15066.15387607,   4461.90989105,   3676.21108785]),
 array([ 16305.2818482,   4528.7785878,   3708.2876111]),
 array([ 18591.93922575,   4550.85493859,   3726.20558207]),
 array([ 18086.46381424,   4925.11442477,   3714.67059675]),
 array([ 18077.41480134,   4798.47761144,   3705.72231196]),
 array([ 20239.66750169,   4970.7436318 ,   3747.48539211]),
 array([ 22566.46287412,   5220.88366887,   3809.03159413]),
 array([ 25926.74157779,   5230.31219083,   3992.10013471])]

In [108]:
first_timescale = []
for comb in first_three_timescales:
    first_timescale.append(comb[0])

In [109]:
plt.scatter(x=lag_times, y=first_timescale)


Out[109]:
<matplotlib.collections.PathCollection at 0x14aa47f60>

In [ ]: