In [2]:
    
import numpy as np
import scipy as sci
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
import new_colormaps as cm
import os
import sys
import warnings
import itertools
import pyemma
print (pyemma.__version__)
import pyemma.coordinates as coor
import pyemma.msm as msm
import pyemma.plots as mplt
warnings.filterwarnings("ignore", category=DeprecationWarning)
#import shapely.geometry as geo
    
In [3]:
    
def cosd(deg):
    return np.cos(deg/180. * np.pi)
def sind(deg):
    return np.sin(deg/180. * np.pi)
def xyz2fracM(crystal_prms):
    a = crystal_prms['a']
    b = crystal_prms['b']
    c = crystal_prms['c']
    alpha = crystal_prms['alpha']
    beta = crystal_prms['beta']
    gamma = crystal_prms['gamma']
    v = np.sqrt(1-cosd(alpha)**2-cosd(beta)**2-cosd(gamma)**2 + 2*cosd(alpha)* cosd(beta)*cosd(gamma))
    r1 = [1./a, -cosd(gamma)/a/sind(gamma), (cosd(alpha)*cosd(gamma)-cosd(beta)) / a/v/sind(gamma)]
    r2 = [0, 1./b/sind(gamma), (cosd(beta)*cosd(gamma)-cosd(alpha)) / b/v/sind(gamma)]
    r3 = [0, 0, sind(gamma)/c/v]
    M = np.array([r1, r2, r3])
    return M
def frac2xyzM(crystal_prms):
    a = crystal_prms['a']
    b = crystal_prms['b']
    c = crystal_prms['c']
    alpha = crystal_prms['alpha']
    beta = crystal_prms['beta']
    gamma = crystal_prms['gamma']
    v = np.sqrt(1-cosd(alpha)**2-cosd(beta)**2-cosd(gamma)**2 + 2*cosd(alpha)* cosd(beta)*cosd(gamma))
    r1 = [a, b*cosd(gamma), c*cosd(beta)]
    r2 = [0, b*sind(gamma), c*(cosd(alpha)-cosd(beta)*cosd(gamma))/sind(gamma)]
    r3 = [0, 0, c*v/sind(gamma)]
    M = np.array([r1, r2, r3])
    return M
def wrap(x, y, z, Mfwd, Mrev):
    fxyz = np.dot(Mfwd, np.array([x, y, z]))
    fxyz_ = fxyz % 1
    wrapped = np.dot(Mrev, np.array(fxyz_))
    wrapped = np.around(wrapped, 8)
    return wrapped 
# Function to read the dump file
def read_file(infiles, frames):
    t = []
    xyz = []
    for infile in infiles:
        with open(infile, 'r') as fin:
            for line in fin:
                if len(t) > frames:
                    break
                line_chunks = line.split()
                if line_chunks[0] != '#':
                    try:
                        # coordinate line: particle index followed by x, y, z
                        ind, x, y, z = map(float, line_chunks)
                        ind = int(line_chunks[0])
                        xyz.append([x, y, z])
                    except ValueError:
                        # frame header line: timestep and number of particles
                        simtime, N = map(int, line_chunks)
                        t.append(float(simtime))
    # Convert timestep to time
    #fs2ps = 0.001
    #t = (np.array(t) - t[0]) * timestep_fs * fs2ps
    # Get number of frames and reshape data array
    numframes = len(xyz) // N
    xyz = np.reshape(xyz, (numframes, N, 3))
    if debug:
        print ("Finished read_file")
    return numframes, N, 3, xyz
def wrapcoords(alldata, Mfwd, Mrev):
    N, trj, dim = np.shape(alldata)
    alldataT= np.transpose(alldata)
    xyz_wrap_T = np.empty((dim, trj, N))
    for i in range(trj):
        frac_T = np.dot(Mfwd, alldataT[:,i,:])
        
        ## offset unit cell
        #frac_T += np.transpose([[.5, .5, 0]])
        
        # wrap unit cell coordinates
        frac_wrap_T = frac_T % 1
        # convert fractional to cartesian coords
        xyz_wrap_T[:, i, :] = np.dot(Mrev, frac_wrap_T)
    xyz_wrap = np.transpose(xyz_wrap_T)
    
    #if debug:
    #    print ("finished wrapcoords")
    return xyz_wrap
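# Sanity check (a minimal sketch, guarded off like the other optional blocks below):
# xyz2fracM and frac2xyzM should be mutual inverses, so their product is ~identity.
# The cell parameters used here are hypothetical placeholders, not the actual cell.
if 0:
    _prms = {'a': 10., 'b': 12., 'c': 15., 'alpha': 90, 'beta': 90, 'gamma': 120}
    assert np.allclose(np.dot(xyz2fracM(_prms), frac2xyzM(_prms)), np.eye(3))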
    
In [4]:
    
debug = True
timestep_fs = 1
adsorbate = 'butane'
trjs = 260  # 200, 260, 270
T = 298
Nframes = 4000
dt = .001  # convert frame indices to ns
### ABENAKI
basedir = '/home/vargaslo/ws/LAMMPS_MD/MSD_analyzed/298K'
sourcedir = os.path.join(basedir, '{}/{}/MDSim_{}K'.format(adsorbate, trjs, T))
### ANISHINAABE
sourcedir = ''
dumpfile = os.path.join(sourcedir, 'lammps_3_extend.dump')
Nframes, _, Ndims, mydat_u_orig = read_file([dumpfile], Nframes)
print (np.shape(mydat_u_orig))
activedir = '{}_N{}_T{}_ff{}'.format(adsorbate, trjs, T, Nframes)
if not os.path.exists(activedir):
    os.makedirs(activedir)
crystal = {}
crystal['a'] = 39.97
crystal['b'] = 40.00
crystal['c'] = 16.58*2
crystal['alpha'] = 90
crystal['beta'] = 90
crystal['gamma'] = 120
    
    
In [5]:
    
class data:
    def __init__(self, origdata, crystal_prms):
        self.origdata = origdata
        self.crystal_prms = crystal_prms        
        self.Mfwd = xyz2fracM(crystal_prms)
        self.Mrev = frac2xyzM(crystal_prms)      
        self.Nframes, self.Ntrjs, self.Ndims = np.shape(self.origdata)
    
    def unwrapped(self):
        return self.origdata
    
    def wrapped(self):
        return wrapcoords(self.origdata, self.Mfwd, self.Mrev)
    #-------------------------------------------------
    # create list of ndarrays(trjs,dims) using indices
    #-------------------------------------------------
    def get_xyz_region(self, hexregions=1):  # hexregions=1 is to split meso/micro
        xyz = self.unwrapped()
        xyz_wrap = self.wrapped()
        # check if inside circle
        P1 = np.dot(self.Mrev, np.array([0.0, 0.0, 0]))   #   P4------P8-------P3
        P2 = np.dot(self.Mrev, np.array([1.0, 0.0, 0]))   #    \        \        \
        P3 = np.dot(self.Mrev, np.array([1.0, 1.0, 0]))   #     \   uc   \   uc   \
        P4 = np.dot(self.Mrev, np.array([0.0, 1.0, 0]))   #      \        \        \
        P5 = np.dot(self.Mrev, np.array([0.0, 0.5, 0]))   #       P5-------P7------P9
        P6 = np.dot(self.Mrev, np.array([0.5, 0.0, 0]))   #        \        \        \
        P7 = np.dot(self.Mrev, np.array([0.5, 0.5, 0]))   #         \   uc   \   uc   \
        P8 = np.dot(self.Mrev, np.array([0.5, 1.0, 0]))   #          \        \        \
        P9 = np.dot(self.Mrev, np.array([1.0, 0.5, 0]))   #           P1-------P6------P2
        region = np.ones((self.Nframes, self.Ntrjs, 1), dtype=int)  # default to region 1 (micro)
                                                                    # region 2 is meso outer, and so on until hexregions+1
        for i in range(hexregions):
            radius = 17 * np.sqrt(1 - 1.*i/hexregions)
            for P in [P1, P2, P3, P4]:
                s2 = np.sum(np.power(xyz_wrap - P, 2)[:,:,0:2], axis=2)
                region[:,:, 0] +=  (s2 <= radius**2)*1  # regions are numbered increasing to the center of channel
        #combined = np.concatenate((xyz, region, xyz_wrap), axis=2)
        reg1 = region==1
        reg2 = region==2
        
        microxyz_w = xyz_wrap*reg1
        mesoxyz_w = xyz_wrap*reg2
    
        combined = np.concatenate((xyz, region, microxyz_w), axis=2)
        return combined
        
    def group_ind(self):
        xyzdata = self.get_xyz_region()
        indices = {}
        indices[1] = []
        indices[2] = []
            
        first = 0
        last = 0
        for p in range(self.Ntrjs):
            for i,v in enumerate(itertools.groupby(xyzdata[:,p,:], lambda x: x[3])):
                key, group = v
                N = len(list(group))
                last = first + N
                indices[int(key)].append((first, last))
            
                first = last
        return indices
    def _region_data(self, xyz, region):
        # concatenate trajectories (order='F' matches the indexing produced by group_ind)
        xyz_ = np.reshape(xyz, (self.Nframes * self.Ntrjs, self.Ndims), order='F')
        data = []
        for ff, lf in self.group_ind()[region]:
            data.append(xyz_[ff:lf, :])
        return data
    def micro_u(self, region=1):
        return self._region_data(self.unwrapped(), region)
    def micro_w(self, region=1):
        return self._region_data(self.wrapped(), region)
    def meso_u(self, region=2):
        return self._region_data(self.unwrapped(), region)
    def meso_w(self, region=2):
        return self._region_data(self.wrapped(), region)
    def xyz2sz(self):
        xyz = self.wrapped()
        # check if inside circle
        P1 = np.dot(self.Mrev, np.array([0.0, 0.0, 0]))   #   P4------P8-------P3
        P2 = np.dot(self.Mrev, np.array([1.0, 0.0, 0]))   #    \        \        \
        P3 = np.dot(self.Mrev, np.array([1.0, 1.0, 0]))   #     \   uc   \   uc   \
        P4 = np.dot(self.Mrev, np.array([0.0, 1.0, 0]))   #      \        \        \
        P5 = np.dot(self.Mrev, np.array([0.0, 0.5, 0]))   #       P5-------P7------P9
        P6 = np.dot(self.Mrev, np.array([0.5, 0.0, 0]))   #        \        \        \
        P7 = np.dot(self.Mrev, np.array([0.5, 0.5, 0]))   #         \   uc   \   uc   \
        P8 = np.dot(self.Mrev, np.array([0.5, 1.0, 0]))   #          \        \        \
        P9 = np.dot(self.Mrev, np.array([1.0, 0.5, 0]))   #           P1-------P6------P2
        s1 = np.sqrt(np.sum((xyz[:,:,0:2]-P1[np.newaxis, np.newaxis, 0:2])**2, axis=2))
        s2 = np.sqrt(np.sum((xyz[:,:,0:2]-P2[np.newaxis, np.newaxis, 0:2])**2, axis=2))
        s3 = np.sqrt(np.sum((xyz[:,:,0:2]-P3[np.newaxis, np.newaxis, 0:2])**2, axis=2))
        s4 = np.sqrt(np.sum((xyz[:,:,0:2]-P4[np.newaxis, np.newaxis, 0:2])**2, axis=2))
        s0 = np.minimum(s1, s2)
        s0 = np.minimum(s0, s3)
        s0 = np.minimum(s0, s4)
        # points farther than 17 A from the nearest mesochannel axis lie in the microchannel
        tmp = s0>17
        
        s = tmp*s0
        z = tmp*xyz[:,:,2]
        sz = np.zeros((self.Nframes, self.Ntrjs, 2))
        sz[:,:,0] = s
        sz[:,:,1] = z
        
        return sz
    
    def tica_sz(self, region=1):  # region 1 is microchannel
        sz = self.xyz2sz()
        sz_ = np.reshape(sz, (self.Nframes * self.Ntrjs, 2), order='F')
        
#        xyz = self.wrapped()
#        xyz_ = np.reshape(xyz, (self.Nframes * self.Ntrjs, self.Ndims), order='F')
        data = []
        for i in self.group_ind()[region]:
            ff, lf = i
#            tmp = xyz_[ff:lf, :]
            tmp = sz_[ff:lf, :]
#            tmp2 = np.zeros((np.shape(tmp)[0], 2))
#            tmp2[:,0] = np.sqrt(tmp[:,0]**2 + tmp[:,1]**2)
#            tmp2[:,1] = tmp[:,2]
            data.append(tmp)
            
#        return list(np.swapaxes(sz, 0, 1))
        return data
        
x = data(mydat_u_orig, crystal)
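# Minimal sketch of the segmentation logic used by data.group_ind() above:
# itertools.groupby splits the per-frame region labels into consecutive runs, and
# each run is stored as a (first, last) index pair. The toy labels are hypothetical.
if 0:
    _toy_labels = [1, 1, 2, 2, 2, 1]
    _segments = {1: [], 2: []}
    _first = 0
    for _key, _grp in itertools.groupby(_toy_labels):
        _last = _first + len(list(_grp))
        _segments[_key].append((_first, _last))
        _first = _last
    # _segments == {1: [(0, 2), (5, 6)], 2: [(2, 5)]}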
    
In [6]:
    
# use this to visualize data
if 0:
    plt.close('all')
    for i in x.tica_sz():
        plt.plot(i[:,0], i[:,1], ls='', marker=',')
    
    plt.gca().set_aspect('equal')
    plt.show()
    
print (len(x.tica_sz()))
    
    
In [28]:
    
Y = x.tica_sz()
def longest(Y):
    # return the index and length of the longest trajectory segment
    index = 0
    maxlen = 0
    for i, v in enumerate(Y):
        if len(v) > maxlen:
            maxlen = len(v)
            index = i
    return index, maxlen
maxduration_ind, maxduration = longest(Y)
print (maxduration, maxduration_ind)
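# Equivalent one-liner (just a sketch of the same computation):
# maxduration_ind, maxduration = max(enumerate(map(len, Y)), key=lambda iv: iv[1])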
if 0:  # Show free energy?
    plt.close('all')
    pyemma.plots.plot_free_energy(np.vstack(Y)[:,0], np.vstack(Y)[:,1], cmap=cm.viridis)
    
    plt.gca().set_aspect('equal')
    plt.xlabel('Distance from mesochannel axis (A)')
    plt.ylabel('Z-distance (A)')
    plt.show()
    
    
In [29]:
    
n_clusters = 512
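# subsample so that k-means clusters on roughly 100,000 frames of the concatenated data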
stride = max(1, (Nframes*trjs)//100000)
print ('stride ', stride)
clustering = coor.cluster_kmeans(Y, k=n_clusters, max_iter=200, stride=stride, fixed_seed=True)
dtrajs = clustering.dtrajs
    
    
In [41]:
    
def dtraj_fig(dtrajs):
    
    # show discrete trajectory time series
    plt.close('all')
    plt.figure()
    plt.plot(dtrajs[maxduration_ind], marker='.', ls='-', lw=.1)
    plt.ylabel('cluster')
    plt.xlabel('frame index')
    plt.tight_layout()
    
    
    plt.figure(figsize=(4,5))
    pyemma.plots.plot_free_energy(np.vstack(Y)[:,0], np.vstack(Y)[:,1], cmap=cm.viridis)
    plt.gca().set_aspect('equal')
    plt.xlabel('Distance from mesochannel axis (A)')
    plt.ylabel('Z-distance (A)')
    plt.xticks([17,19,21,23])
    plt.scatter(clustering.clustercenters[:,0], clustering.clustercenters[:,1], c='r')
    plt.savefig('{}/dtrj0_{}.png'.format(activedir, n_clusters), dpi=144)
    
    return
dtraj_fig(dtrajs)
    
    
    
In [42]:
    
its = msm.timescales_msm(dtrajs, nits=7, n_jobs=-1, lags=np.linspace(1, maxduration/3, 4, dtype=int))
    
In [43]:
    
plt.close('all')
plt.figure(figsize=(7,10))
plt.subplot(211)
mplt.plot_implied_timescales(its, ylog=True, units='steps', linewidth=2)
plt.subplot(212)
mplt.plot_implied_timescales(its, ylog=False, units='steps', linewidth=2)
plt.tight_layout()
plt.savefig('{}/its_{}.png'.format(activedir, n_clusters), dpi=144)
#plt.show()
    
    
In [56]:
    
msm_lag = Nframes//3
msm_lag = 10
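# override: use a short lag of 10 frames, i.e. 10*dt = 0.01 ns at the dump interval above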
M = msm.estimate_markov_model(dtrajs, msm_lag, sparse=False)
print (msm_lag)
    
    
In [57]:
    
print ('fraction of states used = ', M.active_state_fraction)
print ('fraction of counts used = ', M.active_count_fraction)
def mapcc(M):
    out = {}
    for i,v in enumerate(M.largest_connected_set):
        out[i] = {}
        out[i]['orig'] = v
        out[i]['cc'] = clustering.clustercenters[v]
    return out
mapped = mapcc(M)
    
    
In [58]:
    
# test MSM
M = msm.bayesian_markov_model(dtrajs, msm_lag, nsamples=100)  # default nsamples is 100
    
In [59]:
    
def show_trans_mat(M):
        
    plt.figure(figsize=(8,6))
    plt.imshow(np.log10(M.transition_matrix), cmap=cm.viridis, vmin=-5, vmax=-1, interpolation='nearest')
    #plt.imshow(np.log10(M.sample_mean('transition_matrix')), cmap=cm.viridis, vmin=-5, vmax=-1, interpolation='nearest')
    plt.gca().set_axis_bgcolor('k')
    plt.colorbar()
    plt.savefig('{}/trans_matrix_clusters{}_lag{}.png'.format(activedir, n_clusters, msm_lag), dpi=144)
    plt.figure(figsize=(8,6))
    plt.imshow(np.log10(M.count_matrix_active), cmap=cm.viridis, vmin=0, vmax=2, interpolation='nearest')
    plt.ylim(0,24)
    plt.xlim(0,24)
    plt.gca().set_axis_bgcolor('k')
    plt.colorbar()
    plt.savefig('{}/count_matrix_clusters{}_lag{}.png'.format(activedir, n_clusters, msm_lag), dpi=144)
    return
plt.close('all')
show_trans_mat(M)
#plt.show()
    
    
    
    
In [60]:
    
# Show eigenvalue separation
plt.close('all')
plt.figure(figsize=(7,10))
plt.subplot(211)
tmpy = dt * M.timescales()
tmpx = range(2, 2+len(tmpy))
plt.plot(tmpx, tmpy,linewidth=0,marker='o')
plt.ylabel('timescale (ns)')
plt.xlim(-0.5,18.5)
plt.ylim(0,min(10, plt.ylim()[1]))
plt.subplot(212)
tmpy_ = M.timescales()[:-1]/M.timescales()[1:]
tmpx_ = range(2, 2+len(tmpy_))
plt.bar(tmpx_, tmpy_, width=.1)#, linewidth=0,marker='o')
plt.axhline(2, c='r', ls='--')
plt.ylabel('timescale separation'); 
plt.xlim(-0.5,18.5)
plt.ylim(0, 5)
plt.tight_layout()
plt.savefig('{}/eigenval_separation_clusters{}_lag{}.png'.format(activedir, n_clusters, msm_lag), dpi=144)
print ('{:>4} {:>12} {:>12} {:>12}'.format('i', 'eigenval', 'its', 'ratio'))
for i in range(18):
    print ('{:4} {:>12.4f} {:>12.4f} {:>12.4f}'.format(i+2, M.eigenvalues()[i+1], dt*M.timescales()[i], M.timescales()[i]/M.timescales()[i+1]))
    
    
    
In [61]:
    
plt.close('all')
plt.figure(figsize=(7,7))
tmpy = M.eigenvalues()
tmpx = range(2, 2+len(tmpy))
plt.bar(tmpx, tmpy, width=.2)
plt.plot(tmpx, tmpy,linewidth=0,marker='o')
plt.axhline(0, c='r')
plt.ylabel('Eigenvalues'); 
plt.xlabel('Index'); 
plt.xlim(-0.5,124.5)
#plt.ylim(0,min(10, plt.ylim()[1]))
plt.savefig('{}/eigenvalues_clusters{}_lag{}.png'.format(activedir, n_clusters, msm_lag), dpi=144)
    
    
In [62]:
    
def plot_sampled_function(xall, yall, zall, ax=None, nbins=100, nlevels=20, cmap=cm.viridis, cbar=True, 
                          cbar_label=None, vmin=None, vmax=None):
    # histogram data
    xmin = np.amin(xall)
    xmax = np.amax(xall)
    dx = (xmax - xmin) / float(nbins)
    ymin = np.amin(yall)
    ymax = np.amax(yall)
    dy = (ymax - ymin) / float(nbins)
    # bin data
    #eps = x
    xbins = np.linspace(xmin - 0.5*dx, xmax + 0.5*dx, num=nbins)
    ybins = np.linspace(ymin - 0.5*dy, ymax + 0.5*dy, num=nbins)
    xI = np.digitize(xall, xbins)
    yI = np.digitize(yall, ybins)
    # result
    z = np.zeros((nbins, nbins))
    N = np.zeros((nbins, nbins))
    # average over bins
    for t in range(len(xall)):
        z[xI[t], yI[t]] += zall[t]
        N[xI[t], yI[t]] += 1.0
    z /= N
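    # bins with no samples give 0/0 = NaN here; contourf simply leaves those bins blank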
    # do a contour plot
    extent = [xmin, xmax, ymin, ymax]
    if ax is None:
        ax = plt.gca()
    cf = ax.contourf(z.T, nlevels, extent=extent, cmap=cmap, vmin=vmin, vmax=vmax)
    if cbar:
        cbar = plt.colorbar(cf)
        if cbar_label is not None:
            cbar.ax.set_ylabel(cbar_label)
            
    return ax
def plot_sampled_density(xall, yall, zall, ax=None, nbins=100, cmap=cm.viridis, cbar=True, cbar_label=None, vmin=None, vmax=None):
    return plot_sampled_function(xall, yall, zall, ax=ax, nbins=nbins, cmap=cmap, 
                                 cbar=cbar, cbar_label=cbar_label, vmin=vmin, vmax=vmax)
    
In [65]:
    
def show_eigenvec(first_n_ev):
    def proj_ev():
        proj_ev_all = []
        for i in range(min(32,n_clusters)):
            tmp_ = []
            xx = []
            yy = []
            eigval = M.eigenvalues()[i]
            eigvec = np.copy(M.eigenvectors_right()[:,i])
            eigvec = np.append(eigvec, 0)
            for dtraj_mapped in M.discrete_trajectories_active:  # loop over each discrete trajectory
                tmp_.append(eigvec[dtraj_mapped])
            proj_ev_all.append(np.hstack(tmp_))
        return proj_ev_all
    
    proj_ev_all = proj_ev()
    
    ncols = 4; nrows = int(np.ceil(first_n_ev / float(ncols)))
    plt.figure(figsize=(16,nrows*3))
    vmin=-2
    vmax=2
    for i in range(first_n_ev):
      if 1:
        plt.figure(figsize=(9,3))
      if 1:
        plt.subplot(1, 3, 1)
#        plot_sampled_function(tmp[:,0], tmp[:,1], np.vstack(proj_ev_all)[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
        plot_sampled_function(np.vstack(Y)[:,0], np.vstack(Y)[:,1], proj_ev_all[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
#        plot_sampled_function(tmp[:,0], tmp[:,1], (proj_ev_all), cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
        plt.gca().set_aspect('equal')
        plt.gca().set_axis_bgcolor('k')
        plt.ylabel(i)
      if 0:
        plt.subplot(1, 3, 2)
#        plot_sampled_function(tmp[:,0], tmp[:,2], np.vstack(proj_ev_all)[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
        plot_sampled_function(np.vstack(Y)[:,0], np.vstack(Y)[:,2], proj_ev_all[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
        plt.gca().set_aspect('equal')
        plt.gca().set_axis_bgcolor('k')
        plt.subplot(1, 3, 3)
        plot_sampled_function(np.vstack(Y)[:,1], np.vstack(Y)[:,2], proj_ev_all[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
#        plot_sampled_function(tmp[:,1], tmp[:,2], np.vstack(proj_ev_all)[i], cbar=False, cmap='BrBG', vmin=vmin, vmax=vmax, nlevels=50)
        plt.gca().set_aspect('equal')
        plt.gca().set_axis_bgcolor('k')
      if 1:
        plt.figtext(0, .9, 'eig_{} = {}'.format(i, M.eigenvalues()[i]))
        if i>0:
            plt.figtext(0, .8, 'its_{} = {:.4f} ns'.format(i, dt*M.timescales()[i-1]))
        
        plt.tight_layout()
        plt.savefig('{}/eigenvec_lag{}_{:02d}.png'.format(activedir, msm_lag, i), dpi=144)
    return
plt.close('all')
show_eigenvec(4)
    
    
    
    
    
    
    
In [66]:
    
# Number of desired macrostates
n_sets = 10
M.pcca(n_sets)
pcca_dist = np.copy(M.metastable_distributions)  # P(state | metastable)
pcca_dist = np.append(pcca_dist, np.zeros((n_sets,1)), axis=1)
membership = np.copy(M.metastable_memberships)  # P(metastable | state)
membership = np.append(membership, np.zeros((1,n_sets)), axis=0)
# memberships over trajectory
dist_all = [np.hstack([pcca_dist[i,:][dtraj_m] for dtraj_m in M.discrete_trajectories_active]) for i in range(n_sets)]
mem_all = [np.hstack([membership[:,i][dtraj_m] for dtraj_m in M.discrete_trajectories_active]) for i in range(n_sets)]
    
In [68]:
    
def show_pcca_dens():
    def get_centroid(clustercenters, weights):
        active = np.copy(M.active_set)
        sz = np.copy(clustercenters)
        sz = sz[active, :]
        
        fz = sz[:,1] / 16.58  # [0,2)
        fz_P = fz[np.argmax(weights)]  # fz value of highest probability
        fz_lower = fz_P - 1
        new_fz = (fz - fz_lower) % 2 + fz_lower
                
        new_z = new_fz * 16.58
        
        sz[:,1] = new_z
        sz_w = sz * weights[:,np.newaxis]
        
        max_sz = np.sum(sz_w, axis=0)
        return max_sz
    
    ncols = 3; nrows = int(np.ceil(n_sets*3 / float(ncols)))
    ncols=2; nrows=1
    plt.close('all')
    plt.rc('font', size=10)
    
    vmin = 0
    vmax = 100
        
    for i in range(n_sets):
      if 1:
        plt.figure(figsize=(5,5))
        
        # DIST
        plt.subplot(1, ncols, 3*0+1)
        plt.gca().set_axis_bgcolor('k')
        plt.gca().set_aspect('equal')
        plot_sampled_density(np.vstack(Y)[:,0], np.vstack(Y)[:,1], dist_all[i], nbins=300, vmin=0, cmap=cm.viridis, cbar=False)
        scen, zcen = get_centroid(clustering.clustercenters, pcca_dist[i,0:-1])
        plt.scatter(scen, zcen, c='r')
        plt.ylabel('{} pcca_dist P(state|metastable) z={:.1f}'.format(i,zcen))
        plt.xticks([17, 19, 21, 23])
      if 1:
        # MEMB
        plt.subplot(1, ncols, 2)
        plt.gca().set_axis_bgcolor('k')
        plt.gca().set_aspect('equal')
        plot_sampled_density(np.vstack(Y)[:,0], np.vstack(Y)[:,1], mem_all[i], nbins=300, vmin=0, cmap=cm.viridis, cbar=False)
        #plt.scatter(*get_centroid(clustering.clustercenters, pcca_dist[i,0:-1]), c='r')
        plt.ylabel('Membership -- P(metastable {})'.format(i))
        plt.xticks([17, 19, 21, 23])
      if 1:
        plt.savefig('{}/eig_dens_lag{}_nsets{}_{:02d}.png'.format(activedir, msm_lag, n_sets, i), dpi=144)
    return
    
#show_pcca_dens(np.swapaxes(Y,0,1))
show_pcca_dens()
    
    
    
    
    
    
    
    
    
    
    
    
In [69]:
    
ck = M.cktest(n_sets, mlags=4, err_est=False)
    
In [70]:
    
try:
    plt.close('all')
    plt.figure()
    mplt.plot_cktest(ck, diag=True, figsize=(16,8), layout=(4, 3), 
                 padding_top=0.1, y01=True, padding_between=0.15, dt=.001, units='ns')
    plt.savefig('{}/cktest_nclust{}_lag{}_nsets{}.png'.format(activedir, n_clusters, msm_lag, n_sets), dpi=144)
except Exception:
    print ("Figure not created")
    
    
    
Now we want a coarse-grained kinetic model between these metastable states. Coarse-graining of Markov models has been investigated by a number of researchers, so different approaches exist. It is certainly a bad idea to simply bin the clusters into a handful of groups, e.g. using the PCCA memberships, and then re-estimate an MSM on those groups: the result is a very poor MSM that will most likely not recover timescales anywhere near those seen above and will fail the CK test.
We recommend the following approach instead: use the MSM and the metastable states computed by PCCA to estimate a hidden Markov model (HMM) with n_sets hidden states. This can be achieved simply by calling coarse_grain on the MSM:
In [71]:
    
# Hidden Markov Model
hmm = M.coarse_grain(n_sets)
    
In [72]:
    
import matplotlib.patheffects as path_effects
# View stationary distribution of states
plt.close('all')
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.bar([i-.25 for i in range(n_sets)], hmm.stationary_distribution, width=.5)
# View transition matrix
plt.subplot(122)
plt.imshow(np.log10(hmm.transition_matrix), cmap=cm.viridis, vmin=-4, vmax=0, interpolation='nearest')
plt.gca().set_axis_bgcolor('k')
# Label with values
for i in range(n_sets):
    for j in range(n_sets):
        #print j,i,hmm.transition_matrix[j,i]
        if hmm.transition_matrix[j,i]>5e-4:
            val = '{:.3f}'.format(hmm.transition_matrix[j,i])
            text = plt.gca().text(i,j, val, color='white', ha='center', va='center', size=60//n_sets)
            text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])
plt.colorbar()
plt.savefig('{}/pcca_clusters{}_lag{}_nsets{}.png'.format(activedir, n_clusters, msm_lag, n_sets), dpi=144)
    
    
    
In [73]:
    
A, B = [3], [5]
#A, B = [4], [9]
A,B = [0,2,3,5], [1, 6, 7, 8]
tpt = msm.tpt(hmm, A,B)
print ('**forward committor**: ')
print (tpt.committor)
#print ('Gross TPT flux = ', tpt.gross_flux)
#print ('Net TPT flux = ', tpt.net_flux)
print ('')
print ('Total TPT flux = ', tpt.total_flux)
print ('Round trip time = ', msm_lag/tpt.total_flux)
#print ('Rate from TPT flux = ', tpt.rate)
print ('')
print ('A->B transition time = ', msm_lag/tpt.rate)
print ('B->A transition time = ', msm_lag/msm.tpt(hmm, B,A).rate)
print ('Round trip time = ', msm_lag/msm.tpt(hmm, A,B).rate+msm_lag/msm.tpt(hmm, B,A).rate)
print ('')
print ('mfpt(A,B) = ', hmm.mfpt(A,B))
print ('mfpt(B,A) = ', hmm.mfpt(B,A))
print ('Round trip time = {}'.format(np.sum((hmm.mfpt(B, A), hmm.mfpt(A,B)))))
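# oneway: half of the A<->B round-trip time, computed two ways below
# (from the TPT total flux and from the sum of the two MFPTs)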
oneway = msm_lag/tpt.total_flux/2
oneway_ = np.sum((hmm.mfpt(B, A), hmm.mfpt(A,B)))/2
    
    
In [74]:
    
estD = 0.5*16.58**2/oneway * 1e-8  # m2/s
estD_ = 0.5*16.58**2/oneway_ * 1e-8  # m2/s
print (estD,estD_)
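# Unit bookkeeping for the 1e-8 factor (a sketch, assuming oneway is expressed in frames):
# D = <z^2> / (2*tau) with z in Angstrom and tau in frames, and
# 1 A^2/frame = (1e-10 m)^2 / (0.001 ns) = 1e-20 m^2 / 1e-12 s = 1e-8 m^2/s at dt = 0.001 ns.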
    
    
In [75]:
    
# position the states according to the committor (first coordinate), with a spread along the second coordinate
tptpos = np.array([tpt.committor, [
                                   .0, # 0
                                   .2, # 1
                                   .3, # 2
                                   .5, # 3
                                   .4, # 4
                                   .9, # 5
                                   .1, # 6
                                   .6, # 7
                                   .7, # 8
                                   .8, # 9
                               #    0, # 10
                               #    0, # 11
                               #    0 # 12
                                  ]]).transpose()
tptpos =  np.array([tpt.committor, 1.0*np.arange(n_sets)[np.argsort(tpt.committor)]/n_sets]).transpose()
#print (tptpos)
minflux = 1e-7
plt.close('all')
print ('\n**Gross flux illustration**: ')
#mplt.plot_flux(tpt, pos=tptpos, arrow_label_format="%10.1e", attribute_to_plot='gross_flux', minflux=1e-5)
mplt.plot_flux(tpt, pos=tptpos, arrow_label_format="%10.2e", attribute_to_plot='gross_flux', minflux=minflux, fontsize=44)
plt.savefig('{}/tpt_gross_flux_clusters{}_lag{}_A{}_B{}.png'.format(activedir, n_clusters, msm_lag, A, B), dpi=144)
plt.close('all')
print ('\n**Net flux illustration**: ')
mplt.plot_flux(tpt, pos=tptpos, arrow_label_format="%10.2e", attribute_to_plot='net_flux', minflux=minflux, fontsize=44)
plt.savefig('{}/tpt_net_flux_clusters{}_lag{}_A{}_B{}.png'.format(activedir, n_clusters, msm_lag, A, B), dpi=144)
plt.close('all')
print ('\n**Net percentage flux illustration**: ')
mplt.plot_flux(tpt, pos=tptpos, flux_scale=100.0/tpt.total_flux, arrow_label_format="%3.1f", minflux=minflux, fontsize=44)
plt.title('Estimated Dz from tpt rate: {:.2e} m2 s-1'.format(estD))
plt.savefig('{}/tpt_pct_flux_clusters{}_lag{}_A{}_B{}.png'.format(activedir, n_clusters, msm_lag, A, B), dpi=144)
    
    
    
In [ ]: