In [1]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
%matplotlib inline
import mdtraj
plt.rc('font', family='serif')

from msmbuilder.example_datasets import FsPeptide
dataset = FsPeptide().get()
fs_trajectories = dataset.trajectories
fs_t = fs_trajectories[0]


loading trajectory_1.xtc...
loading trajectory_10.xtc...
loading trajectory_11.xtc...
loading trajectory_12.xtc...
loading trajectory_13.xtc...
loading trajectory_14.xtc...
loading trajectory_15.xtc...
loading trajectory_16.xtc...
loading trajectory_17.xtc...
loading trajectory_18.xtc...
loading trajectory_19.xtc...
loading trajectory_2.xtc...
loading trajectory_20.xtc...
loading trajectory_21.xtc...
loading trajectory_22.xtc...
loading trajectory_23.xtc...
loading trajectory_24.xtc...
loading trajectory_25.xtc...
loading trajectory_26.xtc...
loading trajectory_27.xtc...
loading trajectory_28.xtc...
loading trajectory_3.xtc...
loading trajectory_4.xtc...
loading trajectory_5.xtc...
loading trajectory_6.xtc...
loading trajectory_7.xtc...
loading trajectory_8.xtc...
loading trajectory_9.xtc...

In [38]:


In [40]:


In [44]:


In [46]:


In [15]:
from msmbuilder.example_datasets import AlanineDipeptide
ala = AlanineDipeptide().get()
t = ala.trajectories[1]

In [16]:
from msmbuilder.featurizer import DihedralFeaturizer
dih_model = DihedralFeaturizer()
traj = dih_model.fit_transform([t])

In [26]:
from msmbuilder.decomposition import tICA
tica = tICA(lag_time=100,n_components=2)
tica.fit(traj)
X = tica.transform(traj)

In [27]:
from scipy.spatial.distance import pdist
raw_dist = pdist(traj[0])
tica_dist = pdist(X[0])

In [ ]:
from sklearn.decomposition import

In [ ]:
plt.scatter(

In [19]:
len(traj[0])


Out[19]:
10000

In [ ]:


In [28]:
from scipy.spatial.distance import squareform

def plot_dist_mat(dist_mat,filename):
    plt.imshow(dist_mat,interpolation='none',cmap='Blues');
    plt.savefig(filename,dpi=300)

In [30]:
d = squareform(raw_dist)[::2,::2]

plot_dist_mat(d,'raw_dist_mat.png')



In [ ]:
d = squareform(pdist(np.vstack(X_tica)))

plot_dist_mat(d,'tica_dist_mat.png')

In [6]:
# create a combination trajectory from 10 of the individual trajectories
fs_group = [ind[::15] for ind in fs_trajectories[:10]]

len(fs_group)*len(fs_group[0])


Out[6]:
6670

In [52]:
# concatenate them into a single trajectory
fs_group_traj = fs_group[0]
for ind in fs_group[1:]:
    fs_group_traj = fs_group_traj + ind
fs_traj = dih_model.fit_transform([fs_group_traj])
raw_dist = pdist(fs_traj[0])

d = squareform(raw_dist)

plot_dist_mat(d,'fs_raw_dist_mat.png')



In [81]:
# compute RMSD distance matrix

rmsd_mat = np.zeros((len(fs_group_traj),len(fs_group_traj)))

for i in range(1,len(fs_group_traj)):
    rmsd_mat[i,:i] = rmsd(fs_group_traj[:i],fs_group_traj,i)

In [ ]:
rms

In [195]:
rmsd_mat = rmsd_mat + rmsd_mat.T

In [196]:
plot_dist_mat(rmsd_mat,'fs_rmsd.png')



In [63]:
xyz = fs_group_traj.xyz

In [65]:
xyz.shape


Out[65]:
(6670, 264, 3)

In [95]:
# compute weights
#from MDAnalysis.analysis import align
from mdtraj.geometry import alignment

def compute_atomwise_deviation_xyz(X_xyz,Y_xyz):
    X_prime = alignment.transform(X_xyz, Y_xyz)
    delta = X_prime - Y_xyz
    deviation = ((delta**2).sum(1))**0.5
    return deviation

def compute_atomwise_deviation(X,Y):
    return compute_atomwise_deviation_xyz(X.xyz[0],Y.xyz[0])

tau=20
deviations = []

for i,fs_t in enumerate(fs_trajectories):
    print(i)
    n_frames=len(fs_t)-tau
    n_atoms = fs_t.n_atoms
    atomwise_deviations=np.zeros((n_frames,n_atoms))
    for i in range(n_frames):
        atomwise_deviations[i] = compute_atomwise_deviation(fs_t[i],fs_t[i+tau])
    deviations.append(atomwise_deviations)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

In [96]:
# compute weights as inverse mean deviation
mean = np.mean(np.vstack(deviations),0)
weights = 1/mean

In [98]:
np.save('fs_atomwise_deviations_tau=20.npy',np.vstack(deviations))

In [100]:
np.vstack(deviations).shape


Out[100]:
(279440, 264)

In [99]:
mean = np.mean(np.vstack(deviations),0)
stdev = np.std(np.vstack(deviations),0)
plt.plot(mean,c='darkblue')
#plt.plot(mean+stdev,c='blue',linestyle='--')
#plt.plot(mean-stdev,c='blue',linestyle='--')
plt.fill_between(range(len(mean)),mean-stdev,mean+stdev,color='blue',alpha=0.3)
plt.xlabel('Atom index')
plt.ylabel(r'Mean displacement, computed at $\tau=1$ns')
plt.title('Atomwise deviation from Fs peptide simulation')
plt.xlim(0,263)
plt.savefig('atomwise_deviation.pdf')



In [72]:
# compute weighted RMSD distance matrix


from MDAnalysis.analysis import align
learned_weights = weights

wrmsd_matrix = np.zeros((len(xyz),len(xyz)))

for i in range(len(xyz)):
    if i % 100 == 0:
        print(i)
    for j in range(i):
        wrmsd_matrix[i,j] = align.rms.rmsd(xyz[i],xyz[j],weights)

In [105]:
learned_weights = weights

In [ ]:


In [ ]:
# compute autocorrelation plots

# see below!

In [ ]:
# use these distances in k-medoids clustering

In [104]:
from alt_kmed import AltMiniBatchKMedoids

In [106]:
def wrmsd(X,Y):
    return align.rms.rmsd(X,Y,learned_weights)

In [107]:
kmed = AltMiniBatchKMedoids(metric=('callable',wrmsd))

In [108]:
kmed.fit(xyz)


Out[108]:
AltMiniBatchKMedoids(batch_size=100, max_iter=5, max_no_improvement=10,
           metric=('callable', <function wrmsd at 0x174355e60>),
           n_clusters=8, random_state=None)

In [109]:
all_xyz = np.vstack([ind.xyz for ind in fs_trajectories])

In [110]:
all_xyz.shape


Out[110]:
(280000, 264, 3)

In [111]:
clusters = kmed.transform(all_xyz)

In [119]:
clusters = np.array(clusters,dtype=int)

In [120]:
np.save('wrmsd_clusters.npy',clusters)

In [121]:
clusters_list = []
ind = 0
for i in range(len(fs_trajectories)):
    length = len(fs_trajectories[i])
    clusters_list.append(clusters[ind:ind+length])
    ind+= length

In [2]:
import pyemma

In [3]:
deviations = np.load('fs_atomwise_deviations_tau=20.npy')
mean = np.mean(deviations)
weights = 1/mean

In [4]:
from msmbuilder.featurizer import DihedralFeaturizer
dih_model = DihedralFeaturizer()
X = dih_model.fit_transform(fs_trajectories)

In [5]:
X_tica = pyemma.coordinates.tica(X,lag=20).transform(X)

In [ ]:


In [122]:
import pyemma
from pyemma.msm import its

In [129]:
lags = [1,2,3,4,5,10,20,50,100,200,300,400,500,1000]
implied_timescales = its(clusters_list,lags,nits=1,errors='bayes')

In [136]:
implied_timescales.sample_std


Out[136]:
array([[   0.72403819],
       [   1.33682384],
       [   2.26689548],
       [   2.8437467 ],
       [   3.30752266],
       [   5.84986138],
       [   8.42323087],
       [  17.40073782],
       [  27.62846603],
       [  65.83622162],
       [  74.91780514],
       [  99.17232957],
       [ 138.26448743],
       [ 230.15441369]])

In [160]:
wrmsd_timescales = implied_timescales
wrmsd_mean = wrmsd_timescales.timescales[:,0]
wrmsd_std = wrmsd_timescales.sample_std[:,0]

In [157]:
# unweighted rmsd
kmed_ = msmbuilder.cluster.MiniBatchKMedoids(metric='rmsd')
rmsd_clusters = kmed_.fit_transform(fs_trajectories)

rmsd_timescales = its(rmsd_clusters,lags,nits=1,errors='bayes')

rmsd_mean = rmsd_timescales.timescales[:,0]
rmsd_std = rmsd_timescales.sample_std[:,0]

In [188]:
from numpy.linalg import det
def BC(X,Y):
    return 1 - det(X.T.dot(Y)) / np.sqrt(det(X.T.dot(X)) * det(Y.T.dot(Y)))

In [189]:
# unweighted binet-cauchy

kmed_ = AltMiniBatchKMedoids(metric=('callable',BC))
kmed_.fit(xyz)
bc_clusters = kmed_.transform(all_xyz)


bc_clusters = np.array(bc_clusters,dtype=int)
clusters_list = []
ind = 0
for i in range(len(fs_trajectories)):
    length = len(fs_trajectories[i])
    clusters_list.append(bc_clusters[ind:ind+length])
    ind+= length

bc_timescales = its(clusters_list,lags,nits=1,errors='bayes')

bc_mean = bc_timescales.timescales[:,0]
bc_std = bc_timescales.sample_std[:,0]

In [190]:
def BC_w(X,Y,M=None):
    if M==None:
        M = np.diag(np.ones(len(X)))
        
    return det(X.T.dot(M).dot(Y)) / np.sqrt(det(X.T.dot(M).dot(X)) * det(Y.T.dot(M).dot(Y)))

In [191]:
def weighted_BC(X,Y):
    M = np.diag(learned_weights)
    return 1 - (det(X.T.dot(M).dot(Y)) / np.sqrt(det(X.T.dot(M).dot(X)) * det(Y.T.dot(M).dot(Y))))

In [192]:
# weighted binet-cauchy

kmed_ = AltMiniBatchKMedoids(metric=('callable',weighted_BC))
kmed_.fit(xyz)
wbc_clusters = kmed_.transform(all_xyz)


wbc_clusters = np.array(wbc_clusters,dtype=int)
clusters_list = []
ind = 0
for i in range(len(fs_trajectories)):
    length = len(fs_trajectories[i])
    clusters_list.append(wbc_clusters[ind:ind+length])
    ind+= length

wbc_timescales = its(clusters_list,lags,nits=1)#,errors='bayes')

wbc_mean = wbc_timescales.timescales[:,0]
wbc_std = wbc_timescales.sample_std[:,0]


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-192-a5cb8293bb1b> in <module>()
     17 
     18 wbc_mean = wbc_timescales.timescales[:,0]
---> 19 wbc_std = wbc_timescales.sample_std[:,0]

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/pyEMMA-2.0-py2.7-macosx-10.5-x86_64.egg/pyemma/msm/estimators/implied_timescales.pyc in sample_std(self)
    318 
    319         """
--> 320         return self.get_sample_std()
    321 
    322     def get_sample_std(self, process=None):

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/pyEMMA-2.0-py2.7-macosx-10.5-x86_64.egg/pyemma/msm/estimators/implied_timescales.pyc in get_sample_std(self, process)
    340         if self._its_samples is None:
    341             raise RuntimeError('Cannot compute sample mean, because no samples were generated ' +
--> 342                                ' try calling bootstrap() before')
    343         # OK, go:
    344         if process is None:

RuntimeError: Cannot compute sample mean, because no samples were generated  try calling bootstrap() before

In [146]:
# raw
kmed_ = msmbuilder.cluster.MiniBatchKMedoids()
raw_clusters = kmed_.fit_transform(dih_model.fit_transform(fs_trajectories))

raw_timescales = its(raw_clusters,lags,nits=1,errors='bayes')

raw_mean = raw_timescales.timescales[:,0]
raw_std = raw_timescales.sample_std[:,0]

In [152]:
# tica
from pyemma.coordinates import tica

kmed_ = msmbuilder.cluster.MiniBatchKMedoids()
X = dih_model.fit_transform(fs_trajectories)
X_tica = tica(X,lag=20).transform(X)
tica_clusters = kmed_.fit_transform(X_tica)

tica_timescales = its(tica_clusters,lags,nits=1,errors='bayes')

tica_mean = tica_timescales.timescales[:,0]
tica_std = tica_timescales.sample_std[:,0]

In [ ]:
# tica pairwise

kmed_ = msmbuilder.cluster.MiniBatchKMedoids()
feat = pyemma.coordinates.data.featurizer

X = dih_model.fit_transform(fs_trajectories)
X_tica = tica(X,lag=20).transform(X)
tica_clusters = kmed_.fit_transform(X_tica)



tica_timescales = its(tica_clusters,lags,nits=1,errors='bayes')

tica_mean = tica_timescales.timescales[:,0]
tica_std = tica_timescales.sample_std[:,0]

In [193]:
# Raw Dihedral features
mean = raw_mean
std = raw_std
plt.plot(lags,mean,c='darkgrey',label='Raw dihedral features')
plt.fill_between(lags,mean-std,mean+std,alpha=0.3,color='grey')

# tICA (dihedral)
mean = tica_mean
std = tica_std
plt.plot(lags,mean,c='darkred',label='tICA-transformed dihedral features')
plt.fill_between(lags,mean-std,mean+std,alpha=0.3,color='red')

# Binet-Cauchy kernel
mean = bc_mean
std = bc_std

plt.plot(lags,mean,c='purple',linestyle='-',label='Binet-Cauchy kernel')
plt.fill_between(lags,mean-std,mean+std,alpha=0.3,color='purple')

# RMSD
mean = rmsd_mean
std = rmsd_std

plt.plot(lags,mean,c='darkgreen',linestyle=':',label='RMSD')
plt.fill_between(lags,mean-std,mean+std,alpha=0.3,color='green')

# Weighted RMSD
mean = wrmsd_mean
std = wrmsd_std

plt.plot(lags,mean,c='darkblue',linestyle='--',label='Weighted RMSD')
plt.fill_between(lags,mean-std,mean+std,alpha=0.3,color='blue')


plt.legend(loc='best')

plt.yscale('log')
plt.xlabel('lag time / steps')
plt.ylabel('timescale / steps')

plt.title('Implied timescales of MSMs built using various metrics')

plt.savefig('implied_timescales_.pdf')



In [143]:
std.shape,mean.shape,len(lags)


Out[143]:
((14,), (14,), 14)

In [126]:
pyemma.plots.plot_implied_timescales(implied_timescales)


Out[126]:
<matplotlib.axes._subplots.AxesSubplot at 0x1cdb066d0>

In [ ]:
implied_timescales

In [ ]:


In [3]:
from mdtraj import rmsd

In [4]:
rmsd(fs_t[:10],fs_t[0])


Out[4]:
array([ 0.        ,  0.19644617,  0.21715195,  0.24179077,  0.2577042 ,
        0.19864696,  0.19607922,  0.20785291,  0.25813895,  0.29692328], dtype=float32)

In [5]:
def rmsd_diffs(traj,lag=10):
    ''' at each frame, calculate the rmsd to the next 10 or so frames'''
    all_rmsd_diffs = np.zeros((traj.n_frames-lag,lag))
    for i in range(len(all_rmsd_diffs)):
        all_rmsd_diffs[i] = rmsd(traj[i:(lag+i)],traj[i])
    return all_rmsd_diffs

In [6]:
diffs = rmsd_diffs(fs_t,100)

In [7]:
plt.plot(diffs.sum(1))


Out[7]:
[<matplotlib.lines.Line2D at 0x1469fe8d0>]

In [8]:
plt.plot(diffs.max(1))


Out[8]:
[<matplotlib.lines.Line2D at 0x146911390>]

In [9]:
plt.plot(diffs[:,1:].min(1))


Out[9]:
[<matplotlib.lines.Line2D at 0x146876510>]

In [10]:
diffs.shape


Out[10]:
(9900, 100)

In [11]:
plt.plot(np.median(diffs[:,1:],1))


Out[11]:
[<matplotlib.lines.Line2D at 0x1467e2050>]

In [12]:
def ma(data,n=50):
    ''' compute moving average'''
    w = np.array([1.0/n]*n)
    return np.convolve(data,w,'valid')

In [13]:
diffs.shape


Out[13]:
(9900, 100)

In [14]:
plt.plot(ma(np.median(diffs[:,1:],1)))
plt.figure()
plt.plot(ma(np.mean(diffs[:,1:],1)))
plt.figure()
plt.plot(ma(np.max(diffs[:,1:],1)))
#plt.plot(ma(np.min(diffs[:,1:],1)))


Out[14]:
[<matplotlib.lines.Line2D at 0x148771dd0>]

In [15]:
diffs_t = rmsd_diffs(t,100)

In [16]:
n=500
plt.plot(ma(np.median(diffs_t[:,1:],1),n))
plt.figure()
plt.plot(ma(np.mean(diffs_t[:,1:],1),n))
plt.figure()
plt.plot(ma(np.max(diffs_t[:,1:],1),n))
#plt.plot(ma(np.min(diffs[:,1:],1)))


Out[16]:
[<matplotlib.lines.Line2D at 0x1493c8c90>]

In [17]:
ns = [100,200,500,1000]
for n in ns:
    plt.plot(range(n-1,len(diffs_t)),ma(np.median(diffs_t[:,1:],1),n))



In [18]:
ma(np.median(diffs_t[:,1:],1),n).shape


Out[18]:
(8900,)

In [19]:
n


Out[19]:
1000

In [20]:
len(diffs_t)-n


Out[20]:
8899

In [21]:
plt.plot(diffs_t.mean(0)[1:],label='Mean')
plt.plot(np.median(diffs_t[:,1:],0),label='Median')
#plt.plot(np.min(diffs_t[:,1:],0),label='Min')
#plt.plot(np.max(diffs_t[:,1:],0),label='Max')
plt.xlabel(r'Time lag ($\tau$)')
plt.ylabel('RMSD')
plt.legend(loc='best')
plt.title('Alanine Dipeptide')


Out[21]:
<matplotlib.text.Text at 0x1487d0e90>

In [22]:
diffs_t.shape


Out[22]:
(9899, 100)

In [23]:
plt.plot(diffs.mean(0)[1:],label='Mean')
plt.plot(np.median(diffs[:,1:],0),label='Median')
#plt.plot(np.min(diffs[:,1:],0),label='Min')
#plt.plot(np.max(diffs[:,1:],0),label='Max')
plt.xlabel(r'Time lag ($\tau$)')
plt.ylabel('RMSD')
plt.legend(loc='best')
plt.title('Fs Peptide')


Out[23]:
<matplotlib.text.Text at 0x1496ffdd0>

In [24]:
plt.hist(diffs[:,-1],bins=50);
plt.figure()
plt.hist(diffs_t[:,-1],bins=50);



In [25]:
import sys
sys.path.append('../projects/metric-learning')
import weighted_rmsd
reload(weighted_rmsd)


Out[25]:
<module 'weighted_rmsd' from '../projects/metric-learning/weighted_rmsd.pyc'>

In [26]:
from weighted_rmsd import wRMSD,compute_kinetic_weights

In [204]:
weights = compute_kinetic_weights(t,10)
weights_10 = compute_kinetic_weights(t,10)
weights_100 = compute_kinetic_weights(t,100)

In [227]:
def windowed_deviations(traj,lags=np.arange(10)*10+1):
    ''' at each frame, calculate the atomwise deviations to the next 10 or so frames'''
    all_dev_diffs = np.zeros((traj.n_frames-max(lags),len(lags),traj.n_atoms))
    for i in range(len(all_dev_diffs)):
        for j,tau in enumerate(lags):
            all_dev_diffs[i,j-1] = weighted_rmsd.compute_atomwise_deviation(traj[i+j],traj[i])
        if i % 1000 == 0:
            print(i)
    return all_dev_diffs

In [228]:
dev_t = windowed_deviations(t)


0
1000
2000
3000
4000
5000
6000
7000
8000
9000

In [229]:
dev_t.shape


Out[229]:
(9908, 10, 22)

In [30]:
dev_t.shape


Out[30]:
(9899, 100, 22)

In [31]:
dev_t.dot(weights).shape


Out[31]:
(9899, 100)

In [32]:
%timeit dev_t.dot(weights)


10 loops, best of 3: 34.8 ms per loop

In [47]:
weights = (weights/np.sqrt(sum(weights**2)))
weights.mean()


Out[47]:
0.18414783968453527

In [636]:
def norm(w):
    return np.abs(w)/np.sqrt(np.sum(w**2))

In [48]:
def plot_diffs(diffs,title='',label=''):
    if len(label)>0:
        plt.plot(diffs.mean(0)[1:],label=label)
    else:
        plt.plot(diffs.mean(0)[1:])
    plt.xlabel(r'Time lag ($\tau$)')
    plt.ylabel('RMSD')

    if len(title)>0:
        plt.title(title)

In [49]:
n_atoms = len(weights)

In [52]:
plot_diffs(dev_t.dot(norm(weights))/n_atoms,label='Weighted')
plot_diffs(dev_t.dot(norm(np.ones(dev_t.shape[-1])))/n_atoms,label='Unweighted')
#plot_diffs(diffs,label='Unweighted')
plt.legend(loc='best')


Out[52]:
<matplotlib.legend.Legend at 0x14a097c50>

In [116]:
def objective(weights):
    normalized_weights = norm(weights)
    return np.sum(np.dot(dev_t,normalized_weights)**2)*np.min(weights >= 0)

In [117]:
def lnprob(weights):
    return -np.log(objective(weights))

In [118]:
ones = np.ones(len(weights))
print(objective(ones),objective(weights))
print(lnprob(ones),lnprob(weights))


(142357.34728641112, 53677.060161727517)
(-11.866095705609629, -10.890741004157169)

In [119]:
%timeit objective(weights)


10 loops, best of 3: 35.2 ms per loop

In [120]:
import emcee

In [121]:
weights.shape


Out[121]:
(22,)

In [122]:
lnprob(weights+npr.randn(n_atoms)*0.01)


Out[122]:
-10.871543027811178

In [123]:
from autograd import grad
import autograd.numpy as np

In [113]:
grad_objective = grad(objective)

In [114]:
grad_objective(weights)


Out[114]:
array([ 53223.56740511, -17864.30448145,  53366.86830639,  53259.49151462,
       -11385.08108938,  29740.02818509, -23251.88545877,  15193.74002534,
       -29306.54675912,  -8842.7257977 , -17290.10163157,  31061.0076599 ,
        30729.35418509,  30585.41374849, -11605.92230607,  39055.97039275,
       -14518.18901346,  29112.47084265, -13450.6359189 ,  53865.61097045,
        53557.65771912,  53769.56297641])

In [ ]:


In [130]:
n_walkers = n_atoms*2
sampler = emcee.EnsembleSampler(n_walkers, n_atoms, lambda w:-objective(w))
sampler.run_mcmc(np.random.randn(n_walkers,n_atoms), 100)


Out[130]:
(array([[  2.04838002e+19,   1.09791106e+19,  -2.40203009e+18,
           1.47598905e+19,  -2.37245109e+19,   4.46252773e+18,
          -1.20075333e+19,  -3.47695230e+18,  -1.13709004e+19,
           1.58534115e+18,  -2.39525681e+18,  -1.63931874e+19,
           4.96616941e+18,  -1.41884309e+19,  -1.25881750e+19,
          -1.45922105e+19,   1.08999834e+19,  -7.36863242e+18,
           4.35265354e+18,   1.57625999e+19,   1.55486221e+19,
          -9.69565931e+18],
        [  7.38589319e+19,   3.95876571e+19,  -8.66109871e+18,
           5.32200709e+19,  -8.55440268e+19,   1.60906640e+19,
          -4.32958544e+19,  -1.25369218e+19,  -4.10003237e+19,
           5.71630094e+18,  -8.63664965e+18,  -5.91093170e+19,
           1.79066463e+19,  -5.11595317e+19,  -4.53894765e+19,
          -5.26154607e+19,   3.93023493e+19,  -2.65692481e+19,
           1.56944894e+19,   5.68356041e+19,   5.60640377e+19,
          -3.49598815e+19],
        [  9.51447067e+19,   5.09966035e+19,  -1.11571230e+19,
           6.85578661e+19,  -1.10197415e+20,   2.07278807e+19,
          -5.57735151e+19,  -1.61500025e+19,  -5.28164124e+19,
           7.36369906e+18,  -1.11256958e+19,  -7.61443487e+19,
           2.30672372e+19,  -6.59034850e+19,  -5.84705106e+19,
          -6.77790125e+19,   5.06290556e+19,  -3.42263903e+19,
           2.02175171e+19,   7.32153191e+19,   7.22214266e+19,
          -4.50351387e+19],
        [  2.33784767e+20,   1.25306285e+20,  -2.74147620e+19,
           1.68456876e+20,  -2.70771479e+20,   5.09315212e+19,
          -1.37043834e+20,  -3.96829704e+19,  -1.29777809e+20,
           1.80937251e+19,  -2.73375025e+19,  -1.87097996e+20,
           5.66796473e+19,  -1.61934712e+20,  -1.43670759e+20,
          -1.66543109e+20,   1.24403183e+20,  -8.40993225e+19,
           4.96775099e+19,   1.79900993e+20,   1.77458823e+20,
          -1.10658060e+20],
        [  1.01756772e+20,   5.45406171e+19,  -1.19325239e+19,
           7.33222606e+19,  -1.17855549e+20,   2.21683828e+19,
          -5.96494742e+19,  -1.72723442e+19,  -5.64868837e+19,
           7.87544388e+18,  -1.18988706e+19,  -8.14359698e+19,
           2.46702977e+19,  -7.04834214e+19,  -6.25339028e+19,
          -7.24892776e+19,   5.41475333e+19,  -3.66049347e+19,
           2.16225623e+19,   7.83034196e+19,   7.72404334e+19,
          -4.81648503e+19],
        [  8.10184832e+19,   4.34250822e+19,  -9.50057164e+18,
           5.83790603e+19,  -9.38363352e+19,   1.76503779e+19,
          -4.74927718e+19,  -1.37522058e+19,  -4.49747341e+19,
           6.27039263e+18,  -9.47384368e+18,  -6.48391415e+19,
           1.96424133e+19,  -5.61187684e+19,  -4.97893702e+19,
          -5.77158420e+19,   4.31120851e+19,  -2.91447853e+19,
           1.72157750e+19,   6.23449575e+19,   6.14986564e+19,
          -3.83487247e+19],
        [  5.18176042e+19,   2.77737053e+19,  -6.07634072e+18,
           3.73379305e+19,  -6.00156090e+19,   1.12887762e+19,
          -3.03753066e+19,  -8.79560615e+18,  -2.87648249e+19,
           4.01040627e+18,  -6.05925829e+18,  -4.14696483e+19,
           1.25628396e+19,  -3.58923173e+19,  -3.18441567e+19,
          -3.69137535e+19,   2.75735189e+19,  -1.86403465e+19,
           1.10108217e+19,   3.98744314e+19,   3.93331580e+19,
          -2.45269794e+19],
        [  2.11103064e+19,   1.13149146e+19,  -2.47551354e+18,
           1.52113178e+19,  -2.44501286e+19,   4.59902332e+18,
          -1.23747921e+19,  -3.58328877e+18,  -1.17186767e+19,
           1.63382620e+18,  -2.46853190e+18,  -1.68945871e+19,
           5.11806461e+18,  -1.46223724e+19,  -1.29731833e+19,
          -1.50385066e+19,   1.12333691e+19,  -7.59400070e+18,
           4.48578774e+18,   1.62447129e+19,   1.60241834e+19,
          -9.99220933e+18],
        [  1.51929363e+20,   8.14326178e+19,  -1.78159853e+19,
           1.09474824e+20,  -1.75965867e+20,   3.30987908e+19,
          -8.90604703e+19,  -2.57887223e+19,  -8.43385234e+19,
           1.17585429e+19,  -1.77657762e+19,  -1.21589094e+20,
           3.68343147e+19,  -1.05236285e+20,  -9.33671102e+19,
          -1.08231130e+20,   8.08457105e+19,  -5.46535064e+19,
           3.22838536e+19,   1.16911996e+20,   1.15324903e+20,
          -7.19131891e+19],
        [  9.06162948e+18,   4.85694112e+18,  -1.06260416e+18,
           6.52949353e+18,  -1.04952632e+19,   1.97413205e+18,
          -5.31189909e+18,  -1.53813478e+18,  -5.03026543e+18,
           7.01319236e+17,  -1.05961709e+18,  -7.25203118e+18,
           2.19693416e+18,  -6.27668429e+18,  -5.56876639e+18,
          -6.45531439e+18,   4.82193190e+18,  -3.25974221e+18,
           1.92552133e+18,   6.97306197e+18,   6.87840754e+18,
          -4.28916921e+18],
        [  9.54145978e+19,   5.11412719e+19,  -1.11888139e+19,
           6.87523094e+19,  -1.10509983e+20,   2.07866976e+19,
          -5.59317214e+19,  -1.61958058e+19,  -5.29662218e+19,
           7.38459081e+18,  -1.11572676e+19,  -7.63603353e+19,
           2.31326793e+19,  -6.60903990e+19,  -5.86363535e+19,
          -6.79712435e+19,   5.07726929e+19,  -3.43234633e+19,
           2.02748940e+19,   7.34230196e+19,   7.24262852e+19,
          -4.51628926e+19],
        [  3.86833060e+19,   2.07338588e+19,  -4.53617669e+18,
           2.78738038e+19,  -4.48033359e+19,   8.42739687e+18,
          -2.26760238e+19,  -6.56616398e+18,  -2.14737457e+19,
           2.99388414e+18,  -4.52341586e+18,  -3.09582605e+19,
           9.37851731e+18,  -2.67946159e+19,  -2.37725520e+19,
          -2.75571429e+19,   2.05844197e+19,  -1.39155379e+19,
           8.21990030e+18,   2.97673958e+19,   2.93633104e+19,
          -1.83100849e+19],
        [  9.53796661e+19,   5.11225577e+19,  -1.11847551e+19,
           6.87271187e+19,  -1.10469508e+20,   2.07791084e+19,
          -5.59112356e+19,  -1.61898780e+19,  -5.29468339e+19,
           7.38190184e+18,  -1.11531593e+19,  -7.63323578e+19,
           2.31242253e+19,  -6.60661868e+19,  -5.86148789e+19,
          -6.79463398e+19,   5.07541328e+19,  -3.43108878e+19,
           2.02675008e+19,   7.33961497e+19,   7.23997585e+19,
          -4.51463603e+19],
        [  4.13741582e+18,   2.21761410e+18,  -4.85180260e+17,
           2.98126930e+18,  -4.79198725e+18,   9.01365685e+17,
          -2.42534008e+18,  -7.02289078e+17,  -2.29674618e+18,
           3.20214994e+17,  -4.83807758e+17,  -3.31117554e+18,
           1.00309293e+18,  -2.86584017e+18,  -2.54261705e+18,
          -2.94740006e+18,   2.20163414e+18,  -1.48834997e+18,
           8.79172567e+17,   3.18380757e+18,   3.14058441e+18,
          -1.95837713e+18],
        [  2.73891728e+19,   1.46803222e+19,  -3.21178188e+18,
           1.97356513e+19,  -3.17223729e+19,   5.96690166e+18,
          -1.60554461e+19,  -4.64907726e+18,  -1.52041803e+19,
           2.11977455e+18,  -3.20275344e+18,  -2.19195722e+19,
           6.64032820e+18,  -1.89715421e+19,  -1.68318190e+19,
          -1.95114447e+19,   1.45745091e+19,  -9.85270316e+18,
           5.81998522e+18,   2.10763923e+19,   2.07902823e+19,
          -1.29642024e+19],
        [  8.04392205e+19,   4.31146073e+19,  -9.43268505e+18,
           5.79616251e+19,  -9.31653973e+19,   1.75241967e+19,
          -4.71532074e+19,  -1.36538726e+19,  -4.46531541e+19,
           6.22557387e+18,  -9.40610963e+18,  -6.43755397e+19,
           1.95019870e+19,  -5.57175016e+19,  -4.94333602e+19,
          -5.73031502e+19,   4.28038726e+19,  -2.89363889e+19,
           1.70927128e+19,   6.18992172e+19,   6.10589438e+19,
          -3.80745443e+19],
        [  7.99595079e+19,   4.28574903e+19,  -9.37642369e+18,
           5.76159612e+19,  -9.26097932e+19,   1.74196826e+19,
          -4.68720044e+19,  -1.35724440e+19,  -4.43868553e+19,
           6.18843708e+18,  -9.35003727e+18,  -6.39916377e+19,
           1.93856741e+19,  -5.53852215e+19,  -4.91385567e+19,
          -5.69614125e+19,   4.25485916e+19,  -2.87638245e+19,
           1.69907712e+19,   6.15300734e+19,   6.06948129e+19,
          -3.78474815e+19],
        [  5.15195305e+19,   2.76139449e+19,  -6.04142511e+18,
           3.71231228e+19,  -5.96703555e+19,   1.12238574e+19,
          -3.02005709e+19,  -8.74500188e+18,  -2.85993486e+19,
           3.98734298e+18,  -6.02440836e+18,  -4.12310873e+19,
           1.24905826e+19,  -3.56858270e+19,  -3.16609622e+19,
          -3.67013814e+19,   2.74149274e+19,  -1.85331043e+19,
           1.09475086e+19,   3.96450696e+19,   3.91068913e+19,
          -2.43858941e+19],
        [  1.06070511e+20,   5.68527277e+19,  -1.24383521e+19,
           7.64306014e+19,  -1.22851766e+20,   2.31081486e+19,
          -6.21781720e+19,  -1.80045736e+19,  -5.88815192e+19,
           8.20931410e+18,  -1.24032667e+19,  -8.48882451e+19,
           2.57161384e+19,  -7.34714237e+19,  -6.51848871e+19,
          -7.55623105e+19,   5.64429885e+19,  -3.81567211e+19,
           2.25391881e+19,   8.16228954e+19,   8.05148602e+19,
          -5.02066799e+19],
        [  7.87433334e+19,   4.22056511e+19,  -9.23389116e+18,
           5.67395733e+19,  -9.12011686e+19,   1.71547701e+19,
          -4.61590838e+19,  -1.33659849e+19,  -4.37117115e+19,
           6.09430953e+18,  -9.20786654e+18,  -6.30183329e+19,
           1.90908333e+19,  -5.45427519e+19,  -4.83911350e+19,
          -5.60949695e+19,   4.19014658e+19,  -2.83263018e+19,
           1.67323910e+19,   6.05942397e+19,   5.97716416e+19,
          -3.72718382e+19],
        [  6.30137619e+19,   3.37747373e+19,  -7.38927848e+18,
           4.54054772e+19,  -7.29830929e+19,   1.37279444e+19,
          -3.69384615e+19,  -1.06960493e+19,  -3.49799978e+19,
           4.87692424e+18,  -7.36848694e+18,  -5.04299503e+19,
           1.52772848e+19,  -4.36474909e+19,  -3.87246776e+19,
          -4.48896433e+19,   3.35313018e+19,  -2.26679366e+19,
           1.33899268e+19,   4.84900561e+19,   4.78318194e+19,
          -2.98264972e+19],
        [  1.54682211e+20,   8.29081201e+19,  -1.81387399e+19,
           1.11458448e+20,  -1.79154259e+20,   3.36984861e+19,
          -9.06741736e+19,  -2.62560098e+19,  -8.58666883e+19,
           1.19715809e+19,  -1.80877043e+19,  -1.23792207e+20,
           3.75016902e+19,  -1.07143136e+20,  -9.50588700e+19,
          -1.10192221e+20,   8.23105290e+19,  -5.56437997e+19,
           3.28687871e+19,   1.19030345e+20,   1.17414514e+20,
          -7.32161935e+19],
        [ -1.79400935e+19,  -9.61570828e+18,   2.10374024e+18,
          -1.29269928e+19,   2.07783731e+19,  -3.90836428e+18,
           1.05164252e+19,   3.04517836e+18,   9.95884740e+18,
          -1.38846913e+18,   2.09781079e+18,   1.43574658e+19,
          -4.34946522e+18,   1.24264901e+19,   1.10249613e+19,
           1.27801355e+19,  -9.54640650e+18,   6.45358918e+18,
          -3.81213010e+18,  -1.38051776e+19,  -1.36177752e+19,
           8.49164026e+18],
        [ -3.14985125e+19,  -1.68828867e+19,   3.69364826e+18,
          -2.26967090e+19,   3.64818560e+19,  -6.86214356e+18,
           1.84643220e+19,   5.34660986e+18,   1.74853571e+19,
          -2.43781352e+18,   3.68326640e+18,   2.52082750e+19,
          -7.63660656e+18,   2.18179597e+19,   1.93571979e+19,
           2.24388628e+19,  -1.67611908e+19,   1.13309599e+19,
          -6.69318367e+18,  -2.42385879e+19,  -2.39095575e+19,
           1.49092854e+19],
        [ -9.83202033e+19,  -5.26986529e+19,   1.15295684e+19,
          -7.08459608e+19,   1.13875259e+20,  -2.14197181e+19,
           5.76349617e+19,   1.66890122e+19,   5.45791733e+19,
          -7.60948613e+18,   1.14970101e+19,   7.86856590e+19,
          -2.38371335e+19,   6.81030092e+19,   6.04219623e+19,
           7.00411103e+19,  -5.23188685e+19,   3.53686818e+19,
          -2.08923436e+19,  -7.56589341e+19,  -7.46318279e+19,
           4.65382087e+19],
        [ -1.12158131e+20,  -6.01156487e+19,   1.31522321e+19,
          -8.08171221e+19,   1.29902508e+20,  -2.44343856e+19,
           6.57467259e+19,   1.90378842e+19,   6.22608607e+19,
          -8.68043879e+18,   1.31151655e+19,   8.97602102e+19,
          -2.71920384e+19,   7.76880808e+19,   6.89259991e+19,
           7.98989924e+19,  -5.96823612e+19,   4.03466270e+19,
          -2.38327605e+19,  -8.63074311e+19,  -8.51357945e+19,
           5.30881631e+19],
        [ -1.88828923e+19,  -1.01210401e+19,   2.21428208e+18,
          -1.36063509e+19,   2.18703407e+19,  -4.11375233e+18,
           1.10690918e+19,   3.20521548e+18,   1.04822182e+19,
          -1.46143341e+18,   2.20805582e+18,   1.51119935e+19,
          -4.57803550e+18,   1.30795466e+19,   1.16043600e+19,
           1.34517776e+19,  -1.00480843e+19,   6.79274976e+18,
          -4.01245867e+18,  -1.45306730e+19,  -1.43334285e+19,
           8.93789682e+18],
        [ -1.23168690e+19,  -6.60171809e+18,   1.44433077e+18,
          -8.87509963e+18,   1.42655081e+19,  -2.68330699e+18,
           7.22011056e+18,   2.09068353e+18,   6.83730152e+18,
          -9.53257501e+17,   1.44026993e+18,   9.85720188e+18,
          -2.98614629e+18,   8.53147523e+18,   7.56924911e+18,
           8.77427364e+18,  -6.55413246e+18,   4.43074895e+18,
          -2.61723550e+18,  -9.47802040e+18,  -9.34935983e+18,
           5.82998233e+18],
        [ -2.85023626e+19,  -1.52769879e+19,   3.34235529e+18,
          -2.05377506e+19,   3.30116595e+19,  -6.20943383e+18,
           1.67079905e+19,   4.83802332e+18,   1.58221206e+19,
          -2.20593187e+18,   3.33293325e+18,   2.28104518e+19,
          -6.91022102e+18,   1.97425839e+19,   1.75159080e+19,
           2.03044279e+19,  -1.51668853e+19,   1.02531371e+19,
          -6.05655231e+18,  -2.19330216e+19,  -2.16352632e+19,
           1.34911163e+19],
        [ -6.44455882e+19,  -3.45421831e+19,   7.55719778e+18,
          -4.64371871e+19,   7.46414346e+19,  -1.40398850e+19,
           3.77777871e+19,   1.09390928e+19,   3.57748278e+19,
          -4.98775371e+18,   7.53589493e+18,   5.15758212e+19,
          -1.56244313e+19,   4.46392644e+19,   3.96045887e+19,
           4.59096351e+19,  -3.42932333e+19,   2.31830005e+19,
          -1.36941945e+19,  -4.95918706e+19,  -4.89186665e+19,
           3.05042272e+19],
        [ -1.59527590e+20,  -8.55052018e+19,   1.87070260e+19,
          -1.14949830e+20,   1.84766197e+20,  -3.47541356e+19,
           9.35145295e+19,   2.70784496e+19,   8.85564327e+19,
          -1.23466031e+19,   1.86542686e+19,   1.27669969e+20,
          -3.86764717e+19,   1.10499287e+20,   9.80365519e+19,
           1.13643934e+20,  -8.48889473e+19,   5.73868149e+19,
          -3.38984348e+19,  -1.22758963e+20,  -1.21092483e+20,
           7.55096909e+19],
        [ -2.23195754e+20,  -1.19630695e+20,   2.61730210e+19,
          -1.60826819e+20,   2.58507206e+20,  -4.86246252e+19,
           1.30836587e+20,   3.78855903e+19,   1.23899686e+20,
          -1.72741833e+19,   2.60992872e+19,   1.78623610e+20,
          -5.41123922e+19,   1.54600084e+20,   1.37163366e+20,
           1.58999742e+20,  -1.18768465e+20,   8.02901512e+19,
          -4.74274198e+19,  -1.71752588e+20,  -1.69421031e+20,
           1.05645928e+20],
        [ -6.41332907e+19,  -3.43748090e+19,   7.52060998e+18,
          -4.62121317e+19,   7.42797132e+19,  -1.39718645e+19,
           3.75947173e+19,   1.08860794e+19,   3.56014611e+19,
          -4.96358525e+18,   7.49938615e+18,   5.13258866e+19,
          -1.55487223e+19,   4.44229213e+19,   3.94126571e+19,
           4.56871365e+19,  -3.41270665e+19,   2.30706498e+19,
          -1.36278564e+19,  -4.93515689e+19,  -4.86816049e+19,
           3.03564123e+19],
        [ -6.03991445e+19,  -3.23733245e+19,   7.08264324e+18,
          -4.35214960e+19,   6.99548354e+19,  -1.31583160e+19,
           3.54057828e+19,   1.02522508e+19,   3.35285886e+19,
          -4.67456858e+18,   7.06273054e+18,   4.83374702e+19,
          -1.46433810e+19,   4.18364591e+19,   3.71178919e+19,
           4.30270720e+19,  -3.21399808e+19,   2.17273922e+19,
          -1.28343220e+19,  -4.64780573e+19,  -4.58471464e+19,
           2.85889086e+19],
        [ -4.41043380e+19,  -2.36394811e+19,   5.17186083e+18,
          -3.17800117e+19,   5.10820282e+19,  -9.60839821e+18,
           2.58538114e+19,   7.48634373e+18,   2.44830590e+19,
          -3.41343901e+18,   5.15732389e+18,   3.52967153e+19,
          -1.06928044e+19,   3.05495926e+19,   2.71040164e+19,
           3.14189748e+19,  -2.34690861e+19,   1.58656488e+19,
          -9.37182455e+18,  -3.39389627e+19,  -3.34782512e+19,
           2.08760366e+19],
        [ -1.73348636e+20,  -9.29131288e+19,   2.03277531e+19,
          -1.24908762e+20,   2.00773834e+20,  -3.77651330e+19,
           1.01616393e+20,   2.94244406e+19,   9.62286937e+19,
          -1.34162853e+19,   2.02704391e+19,   1.38730960e+20,
          -4.20273044e+19,   1.20072632e+20,   1.06530153e+20,
           1.23489729e+20,  -9.22435068e+19,   6.23586487e+19,
          -3.68352939e+19,  -1.33394466e+20,  -1.31583613e+20,
           8.20516511e+19],
        [ -1.26855237e+20,  -6.79931261e+19,   1.48756147e+19,
          -9.14073628e+19,   1.46924841e+20,  -2.76362076e+19,
           7.43621135e+19,   2.15326089e+19,   7.04194772e+19,
          -9.81792043e+18,   1.48337438e+19,   1.01522294e+20,
          -3.07552389e+19,   8.78683273e+19,   7.79580153e+19,
           9.03689258e+19,  -6.75030564e+19,   4.56336224e+19,
          -2.69557476e+19,  -9.76170498e+19,  -9.62919218e+19,
           6.00447729e+19],
        [ -1.06167835e+20,  -5.69048768e+19,   1.24497171e+19,
          -7.65007740e+19,   1.22964522e+20,  -2.31293299e+19,
           6.22352393e+19,   1.80210853e+19,   5.89355489e+19,
          -8.21681946e+18,   1.24146791e+19,   8.49661773e+19,
          -2.57397188e+19,   7.35388533e+19,   6.52447202e+19,
           7.56316786e+19,  -5.64947374e+19,   3.81917489e+19,
          -2.25598235e+19,  -8.16977726e+19,  -8.05887564e+19,
           5.02527462e+19],
        [ -1.29848417e+19,  -6.95973920e+18,   1.52266550e+18,
          -9.35641277e+18,   1.50391541e+19,  -2.82883089e+18,
           7.61167327e+18,   2.20406385e+18,   7.20809891e+18,
          -1.00496055e+18,   1.51837204e+18,   1.03917734e+19,
          -3.14809691e+18,   8.99415402e+18,   7.97974152e+18,
           9.25011819e+18,  -6.90958668e+18,   4.67103366e+18,
          -2.75917728e+18,  -9.99203325e+18,  -9.85639486e+18,
           6.14615517e+18],
        [ -1.91657030e+20,  -1.02726243e+20,   2.24746183e+19,
          -1.38101170e+20,   2.21978801e+20,  -4.17537053e+19,
           1.12348696e+20,   3.25321560e+19,   1.06392017e+20,
          -1.48332436e+19,   2.24113278e+19,   1.53383170e+20,
          -4.64660237e+19,   1.32754299e+20,   1.17781477e+20,
           1.36532262e+20,  -1.01985849e+20,   6.89447364e+19,
          -4.07256639e+19,  -1.47483042e+20,  -1.45480962e+20,
           9.07176046e+19],
        [ -5.16431473e+19,  -2.76802094e+19,   6.05593465e+18,
          -3.72121779e+19,   5.98135186e+19,  -1.12507912e+19,
           3.02730385e+19,   8.76597999e+18,   2.86679585e+19,
          -3.99690855e+18,   6.03888370e+18,   4.13300239e+19,
          -1.25205519e+19,   3.57714349e+19,   3.17369175e+19,
           3.67894260e+19,  -2.74807101e+19,   1.85775692e+19,
          -1.09737830e+19,  -3.97402037e+19,  -3.92007236e+19,
           2.44444097e+19],
        [ -2.75090738e+19,  -1.47445965e+19,   3.22588491e+18,
          -1.98220192e+19,   3.18612219e+19,  -5.99304357e+18,
           1.61257288e+19,   4.66941671e+18,   1.52707269e+19,
          -2.12905274e+18,   3.21679474e+18,   2.20155275e+19,
          -6.66940410e+18,   1.90545592e+19,   1.69054883e+19,
           1.95968245e+19,  -1.46383295e+19,   9.89581904e+18,
          -5.84548860e+18,  -2.11686738e+19,  -2.08812894e+19,
           1.30209615e+19],
        [  4.40458177e+18,   2.36081133e+18,  -5.16510436e+17,
           3.17377764e+18,  -5.10141883e+18,   9.59570046e+17,
          -2.58194982e+18,  -7.47639042e+17,  -2.44505415e+18,
           3.40894919e+17,  -5.15044643e+17,  -3.52498340e+18,
           1.06786635e+18,  -3.05089879e+18,  -2.70680062e+18,
          -3.13772149e+18,   2.34380282e+18,  -1.58445573e+18,
           9.35945569e+17,   3.38939498e+18,   3.34338032e+18,
          -2.08483446e+18],
        [ -2.81106972e+20,  -1.50670530e+20,   3.29640903e+19,
          -2.02555494e+20,   3.25580409e+20,  -6.12410146e+19,
           1.64783938e+20,   4.77155107e+19,   1.56047145e+20,
          -2.17562527e+19,   3.28710520e+19,   2.24969921e+20,
          -6.81526575e+19,   1.94713132e+20,   1.72752252e+20,
           2.00254353e+20,  -1.49584676e+20,   1.01122501e+20,
          -5.97331898e+19,  -2.16316188e+20,  -2.13379626e+20,
           1.33057234e+20]]),
 array([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0.]),
 ('MT19937', array([1630903701,  734435289,  863130909, 2672639682, 4000147909,
         3068813056, 3696302767,  280953192, 1315359245, 3810876899,
          470883586, 1203618019,  277816173, 1695638856, 2822117295,
         2846711194, 3481403848, 2948219859, 1567443258, 1072755819,
         1433992878, 1069048000, 3752919297, 2501046498,  536266684,
          798647967, 2231834747,   44774053, 3743674643, 2664557481,
         3233389271, 3638904794, 2904958718, 3066150371, 3953821782,
         3542606976, 1133468662,   38086141, 3669017756, 3568766797,
         1273868461,  267436217, 1413159931, 4162749619, 2982674423,
         2316348366, 2500370843, 3804675535,   72940285, 1252882149,
         2720147631, 3917042120, 1680656334, 3987342037, 2328983022,
          111499469, 3727975347, 1535325332,  179877895,  232278970,
         3250419754,  359466498, 2201177757, 3334242759, 1001735198,
         2055629471, 1902685087, 1736204863,  989604204, 3189138477,
          341712107,  622687371, 2527525250, 2730737932, 2164493652,
         2798233555, 3957158774, 3235937957, 1419135210, 2027476397,
          251109190,   73959196, 1819822405, 1651192478, 2983483725,
           13350540, 3144357740, 3546920766, 2381480230, 2323390399,
         2679903852, 1369119890, 1669160424,  966319624, 1738596407,
         3760703705, 3730434347, 3401071618, 2357916048, 3326649402,
         1584165996, 3446339502, 2002388179, 1404871891, 1976552253,
         3998352011, 4262614843, 2959678007, 3705511199, 1554682634,
         2892077590, 3636419762, 3264537537, 3168345569, 4013513641,
         2661803607, 4133601814, 3576117758, 3927047305,  763668635,
         2314398988, 3318911026, 1177963193, 3987857419, 4170351422,
         3530183602, 2647506041, 3069352913, 3728886561, 2388611593,
         1441283741, 1291197852, 2454198437, 1711333834,  356308444,
         2385399984,  639776757, 2440974072, 1783903721, 4236314655,
         1018617167,  998739667, 4168703074,  461044111, 3204859579,
         3679881154, 4123924370, 1829190007,  893224565, 2825176886,
         3436414366, 2703777977, 3019925775, 2044101276, 3346785026,
         2157211579, 3278902701, 1668918558, 1195225066, 3618465280,
         1032592518, 3743296119, 3718954722, 3308609965, 1215947384,
         3769868737, 1814975163,  615979743, 1267328056, 1982817310,
          163957590, 1004677339, 4264593422, 3564129602,  316047747,
          848418513,  335554416, 1252498934,  693681583, 1377599967,
         3201720023, 2453063826,  688890089,   79919477, 1874934019,
          326704024,  956703863,  998072165, 2252303253, 3666275530,
         1575160144, 3636162407, 1236768180,  326412727, 1717722603,
         1767218073, 4225490491,  908331521, 4068613074,   58179979,
          536267140,  297612640,  561875040, 2660512342, 3865526135,
         4233171445, 2744239493, 2517450522, 3054617508, 4281273086,
         2379022188, 2991301665, 2393949823, 3146229535, 4046221636,
          638801108,  515844126, 3725625923, 3421665260, 3861566544,
         2098742324, 3076760732, 2996826265, 3245372340, 1599653242,
          374823664, 2254893927,  152510253, 3323537842, 1202606574,
         4041593615, 3266468720, 2776343670, 3704878337, 2540608782,
         3281646249, 2239994373, 1077877613, 1050754826, 3635790152,
         3263223206, 3087246570, 2103241253, 1727460450, 1157737086,
          569754136, 1398854072, 3343707339,  278271970,  235025081,
         2675984427, 1283557371, 1161942799,  653439562, 2605397951,
         3581715010, 1306027472, 3274002243, 3856626065, 4215030682,
            1578674, 2874312476,  280914048, 3981877310, 4020434747,
         4055775235, 3338814338,  964513526, 3599309434, 1155858388,
         2836245776, 3867339501,  596399924,  744999092,  537469146,
         1489820357,   65112416, 2808700134, 2772039254, 2178707832,
         1392773822,   64325406, 3745212730,   87535029, 3566266897,
         3869351779, 2828391261, 3785542516, 3813814813,  377699785,
         4153344154,   19897396, 3719218371,   95303577, 1787658481,
         3178064649, 1381960184, 3798384149, 3048675504, 3626072947,
         1060480811, 1648018391, 3313149311, 2191458841, 3986763270,
          550931247, 3153896803, 4180436234, 3902958447,  337652413,
           68848502, 1234470964, 1637760424, 2452035334,   21644204,
         1432561622, 1978899798,  391417622, 1310282055, 3525831550,
         4217578918, 2270539556, 3519292933,  740171224,  204991816,
         2796928991, 2721281133, 3265629217,  542358938,  341266933,
         3791569826,  644666650, 2559457278, 1638462113, 2498356806,
         1474879317, 1796441045,  815581286,  215640753, 2075823783,
           78309245, 2449028869, 1002216757, 3480708875, 1359772563,
         3233075433, 2105069393, 2457070553, 4054248086,  215763967,
         3381801211, 3812814953, 1497613656, 1368437084,  608490263,
          858842287, 3251429976, 2544325776, 3320009961, 2033526025,
         3213748389, 2018868718,  764662185, 2430699226, 1310500524,
         2522796606, 2680143744, 2868507784,  200666993, 3443689940,
         3840208271, 3589586964, 3726118048, 1167862297,  334565705,
         1857023276, 3189209219,  238132907,  399620050, 3959386798,
          954657402, 1287266712, 3213062197,  100707319, 3287185421,
         3004020865,  489916966, 1546304538,  117669096, 3631659269,
         3783522632, 1117917655, 1033694349, 1441259041, 1729903852,
          120744495,  596075505,  247996148,  846853885, 3811680002,
          372119759, 2212492594, 4072536141,  779822500, 1751309309,
         1281370141, 3680704847,  358332058, 1729359622, 3331549465,
         2759771122, 3120256068, 1376091265, 2440631472, 1489811496,
          175251682, 2751819686, 1553248732, 3664115547, 2917889612,
         1523971632, 3590285076, 1497742598,  910040044, 1565922302,
         2972710653, 2472642576,  520253603,   49187214, 1116377874,
         3254108820, 3509199912,  472530979, 1284506065, 2851494061,
         2614476333, 2188952911,  233643185,  466441689, 1338975337,
         2084335152, 2251048151, 3954635243, 3678791708, 3554465985,
         2567517933, 3228336744, 4075066508,  528695837, 2079298156,
         3411100617, 3267429378,   50271781,  853339322, 4196541480,
          896619798,  374209896, 3454889034,  470798724, 1268710530,
         3913986547, 2181918891, 3936829963,  870198863, 3108718179,
         1425700986, 3342796197, 1260184571, 3585104955, 3019523884,
         3213531532, 2316975250,  538928034,  478295483, 3822901880,
         1632967998, 1735383391, 1186687306, 2987049951, 3451281055,
         3821098622,  639679035, 3019602217, 3440162113, 2094334114,
         3142782355,  282597108, 2719279287, 3071931535, 1868704966,
          546340804, 3094001015, 2870928880,  570317058, 3390529108,
         2612781791, 3075309620,  345062891, 1836559561,  777866916,
         3322056934, 1647500168, 1095150001,  712828750, 1627558709,
         2487067713, 2426907216,  469356255, 1979703739, 1163802877,
         3698976239, 4218891595, 3215048265, 3888936943, 1658058713,
         2602415677, 4044227159, 2126240151,  407852888,   90794824,
          153833512, 3523616664, 4092720094, 3601590731, 3324940734,
          462543186, 1410469225,  465122639, 1371708361, 3306970707,
         2785101453, 4084912208, 3944711395, 3885109521, 2891080244,
         2167768244, 1614239780, 2073137890,  913348626, 3784494492,
         2788844574, 2909100324, 4213037432, 3021071288, 2168335529,
          653241193, 1667075457, 1076602855, 2569343302, 4226607349,
         1434486462, 2926545886, 2435388857,  352568915,  440762458,
         2786769735, 3025897544,  219809551, 3587603449, 2249417024,
         2793639275,  647504814, 3417643873, 2478310135,  266842444,
         3492935874, 2980722880, 3159667994, 3858322859, 4128617232,
         3572112913, 2259330649, 1833784790, 1523086773, 3341670580,
           94362336, 3619021178,  188002110, 3391230067, 3362672769,
         2360866919, 3896238894,  123877897,  120313818, 3705633497,
         4134023784, 3215278906, 2383348273, 1325755262,  574410188,
         1956205862, 1211731957, 2252853264, 2896861474, 1832453743,
         1212930728,  971069738, 2563364594, 1476914575,  540108914,
         1539160392, 2415014267,  554281102, 2792026365, 3362211935,
          488169786,  279445362, 2631333189, 2496411852,  414043287,
         2714439032, 3815799207,  247795573,  882683919, 2385146629,
         1483632425, 1031255289,  268890582, 1576438389, 2908862862,
         1386268361,  829859201, 1273069820, 3390381540], dtype=uint32), 287, 0, 0.0))

In [131]:
sampler.flatchain.shape


Out[131]:
(4400, 22)

In [132]:
normed_chain = np.array([norm(s) for s in sampler.flatchain])

In [133]:
normed_chain


Out[133]:
array([[-0.15196116,  0.53463089,  0.05975893, ...,  0.20109551,
         0.19115583, -0.08688488],
       [-0.13607448,  0.50745558,  0.04680355, ...,  0.18634484,
         0.18072572, -0.07982799],
       [-0.13607448,  0.50745558,  0.04680355, ...,  0.18634484,
         0.18072572, -0.07982799],
       ..., 
       [-0.35751461, -0.19162426,  0.04192417, ..., -0.27511308,
        -0.27137828,  0.16922351],
       [-0.35751461, -0.19162426,  0.04192417, ..., -0.27511308,
        -0.27137828,  0.16922351],
       [-0.35751459, -0.19162425,  0.04192405, ..., -0.27511304,
        -0.27137829,  0.16922349]])

In [128]:
import triangle
reload(triangle)


Out[128]:
<module 'triangle' from '/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/triangle.pyc'>

In [135]:
samples = np.abs(normed_chain[1000:,:10])
triangle.corner(samples,truths=np.zeros(samples.shape[1]))


Out[135]:

In [162]:
samples = np.abs(normed_chain[len(normed_chain)/2:,:5])
triangle.corner(samples,truths=weights[:samples.shape[1]])


Out[162]:

In [97]:
objs = [objective(s) for s in normed_chain[-500:]]

In [100]:
np.argmin(objs)


Out[100]:
497

In [101]:
plt.plot(objs)


Out[101]:
[<matplotlib.lines.Line2D at 0x1b9083150>]

In [103]:
plt.scatter(normed_chain[-1],norm(weights))


Out[103]:
<matplotlib.collections.PathCollection at 0x1b5158890>

In [157]:
plot_diffs(np.abs(dev_t.dot(norm(np.ones(dev_t.shape[-1])))/n_atoms),label='Unweighted')
plot_diffs(dev_t.dot(norm(weights))/n_atoms,label='Weighted (deterministic)')
for i in range(1,2):
    plot_diffs(np.abs(dev_t.dot(norm(normed_chain[-i*100]))/n_atoms),label='Weighted (MCMC)')
#plot_diffs(diffs,label='Unweighted')
plt.ylim(0,0.03)
plt.legend(loc='best')


Out[157]:
<matplotlib.legend.Legend at 0x22c40e3d0>

In [161]:
plt.plot(np.abs(normed_chain[-1]),label='MCMC')
plt.plot(weights,label='deterministic')
plt.legend()


Out[161]:
<matplotlib.legend.Legend at 0x2002a7dd0>

In [151]:
plt.scatter(np.abs(normed_chain[-1]),weights)


Out[151]:
<matplotlib.collections.PathCollection at 0x1fec02b50>

In [163]:
t


Out[163]:
<mdtraj.Trajectory with 9999 frames, 22 atoms, 3 residues, without unitcells at 0x106810ad0>

In [169]:
from msmbuilder import featurizer

In [221]:
def compute_dev_mat(t,n=1000):
    stride = int(t.n_frames / n)
    
    dev_mat = np.zeros((n,n,t.n_atoms))
    rpf = featurizer.RawPositionsFeaturizer()
    rpft = np.array(rpf.fit_transform(t.center_coordinates())).reshape((len(t),t.n_atoms,3))
    
    for i in range(n):
        for j in range(i):
            dev_mat[i,j] = weighted_rmsd.compute_atomwise_deviation_xyz(rpft[i*stride],rpft[j*stride])
        if i%100==0:
            print(i)
    return dev_mat

In [220]:
fs_t.n_atoms


Out[220]:
264

In [215]:
dev_mat = compute_dev_mat(t)


0
100
200
300
400
500
600
700
800
900

In [222]:
dev_mat_fs = compute_dev_mat(fs_t)


0
100
200
300
400
500
600
700
800
900

In [198]:
def distmat_plot(dist_mat_raw):
    dist_mat = dist_mat_raw + dist_mat_raw.T
    plt.imshow(dist_mat,interpolation='none',cmap='Blues')

In [205]:
distmat_plot(dev_mat.dot(ones))
plt.title('Unweighted')
plt.savefig('ala_unweighted.pdf')

plt.figure()
distmat_plot(dev_mat.dot(weights_10))
plt.title(r'Weighted (deterministic, $\tau=10$)')
plt.savefig('ala_weighted_det_10.pdf')

plt.figure()
distmat_plot(dev_mat.dot(weights_100))
plt.title(r'Weighted (deterministic, $\tau=100$)')
plt.savefig('ala_weighted_det_100.pdf')

plt.figure()
distmat_plot(np.abs(dev_mat.dot(np.abs(normed_chain[-1000]))))
plt.title('Weighted (MCMC)')
plt.savefig('ala_weighted_mcmc.pdf')



In [213]:
weights_fs = compute_kinetic_weights(fs_t)

In [ ]:
def objective(weights):
    normalized_weights = norm(weights)
    return np.sum(np.dot(dev_t,normalized_weights)**2)*np.min(weights >= 0)

In [223]:
distmat_plot(dev_mat_fs.dot(np.ones(len(weights_fs))))
plt.title('Unweighted')
plt.savefig('fs_unweighted.pdf')

plt.figure()
distmat_plot(dev_mat_fs.dot(weights_fs))
plt.title(r'Weighted (deterministic, $\tau=10$)')
plt.savefig('fs_weighted_det.pdf')



In [209]:
sum(sampler.flatlnprobability==0) /


Out[209]:
4400

In [177]:
dev_mat.shape


Out[177]:
(1000, 1000, 22)

In [230]:
dev_t_fs = windowed_deviations(fs_t)


0
1000
2000
3000
4000
5000
6000
7000
8000
9000

In [231]:
dev_t_fs.shape


Out[231]:
(9909, 10, 264)

In [234]:
dev_t_fs.dot(weights_fs).shape


Out[234]:
(9909, 10)

In [238]:
plt.plot(np.mean(dev_t_fs.dot(weights_fs),1))


Out[238]:
[<matplotlib.lines.Line2D at 0x1502b3ed0>]

In [777]:
def sgd(objective,dataset,init_point,batch_size=20,n_iter=100,step_size=0.01,seed=0):
    ''' objective takes in a parameter vector and an array of data'''
    np.random.seed(seed)
    testpoints = []
    testpoints = np.zeros((n_iter,len(init_point)))
    testpoints[0] = init_point
    #testpoints.append(init_point)
    #shuffled = np.array(dataset)
    #np.random.shuffle(shuffled)
    accept_frac = 1.0*batch_size/dataset.shape[0]
    ind=0
    for i in range(1,n_iter):
        max_ind = ind+batch_size
        if max_ind>=len(dataset):
            ind = max_ind % len(dataset)
            max_ind = ind+batch_size
        
        subset = dataset[ind:max_ind]
        ind = (ind + batch_size)
        #else:
        #    #new_ind = (max_ind-len(dataset))
        #    #subset = np.vstack([dataset[ind:],dataset[:new_ind]])
        #    ind = 
        #    subset = dataset[ind:batch_size]
        #    ind = batch_size
        #subset = dataset[np.random.rand(len(dataset))<accept_frac]
        obj_grad = grad(lambda p:objective(p,subset))
        raw_grad = obj_grad(testpoints[i-1])
        gradient = np.nan_to_num(raw_grad)
        #gradient = np.nan_to_num(obj_grad(testpoints[-1]))
        #print(gradient,raw_grad)
        testpoints[i] = testpoints[i-1] - gradient*step_size
        #testpoints.append(testpoints[-1] - gradient*step_size)
    return np.array(testpoints)

In [511]:
dev_t_fs.shape


Out[511]:
(9909, 10, 264)

In [512]:
9909/9.0


Out[512]:
1101.0

In [1088]:
def batch_objective(weights,batch,penalty_param=100):
    ''' want to minimize this objective'''
    #normalized_weights = norm(np.abs(weights))
    #return np.sum(np.abs(np.dot(batch,normalized_weights)))#-np.sum(np.abs(weights))*penalty_param
    l2_norm = np.sqrt(np.sum(weights**2))
    if l2_norm < 1:
        factor_off = 1.0/l2_norm
    else:
        factor_off=l2_norm
    #factor_off=np.max(l2_norm,1.0/l2_norm)
    functional_obj = np.sum(np.abs(np.dot(batch,weights)))
    nonneg_penalty = np.sum(penalty_param*(np.abs(weights[weights<0])))
    norm_penalty = penalty_param*(np.exp(factor_off))
    return functional_obj#+nonneg_penalty+norm_penalty

def triplet_batch_objective_simple(weights,batch,tau_1=0,tau_2=1):
    loss = 0
    for i in range(len(batch)):
        close = np.dot(dev_t[i][tau_1],weights)
        far = np.dot(dev_t[i][tau_2],weights)
        #contribution = (close-far)
        loss += contribution
        #print(close,far)
        #print(contribution)
    return loss / len(batch)

def triplet_batch_objective(weights,batch,tau_1=0,tau_2=1,tau_3=3):
    loss = 0
    for i in range(len(batch)):
        close = np.dot(dev_t[i][tau_1],weights)
        far = np.dot(dev_t[i][tau_2],weights)
        far2 = np.dot(dev_t[i][tau_3],weights)
        #contribution = np.exp(2*(close - far)/(close+far))
        #contribution = close/far
        #contribution = np.exp(close-far)
        #contribution = (close-far)
        #contribution = far-close
        #contribution=np.exp(10*(close-far))
        contribution = (((close-far)>0)*2-1)*np.exp(np.abs(5*(close-far)))
        contribution = contribution+ (((far-far2)>0)*2-1)*np.exp(np.abs(5*(far-far2)))
        #contribution = ((close-far)>0)*np.exp(np.abs(10*(close-far)))
        loss += contribution
        #print(close,far)
        #print(contribution)
    return loss / len(batch)

In [1090]:
%timeit triplet_batch_objective(weights,dev_t[:10])


1000 loops, best of 3: 307 µs per loop

In [1069]:
np.exp(-100),np.exp(100)


Out[1069]:
(3.7200759760208361e-44, 2.6881171418161356e+43)

In [1070]:
tau_1=0
tau_2=1
triplet_batch_objective(np.ones(22),dev_t,tau_1,tau_2),triplet_batch_objective(weights,dev_t,tau_1,tau_2)


Out[1070]:
(-14.252122909283733, -0.62030407272727617)

In [1091]:
raw_weights = sgd(triplet_batch_objective,dev_t,weights,n_iter=2000,step_size=0.001,batch_size=10)

In [1092]:
plt.plot(raw_weights);
plt.figure()
normed_weights = np.array([norm(s) for s in raw_weights])
plt.plot(normed_weights);



In [1093]:
plt.plot(normed_weights[1])
plt.plot(normed_weights[len(normed_weights)/2])
plt.plot(normed_weights[-1])
plt.plot(norm(norm(weights)),label='deterministic')
plt.plot(norm(1/norm(weights)),label='deterministic_inv')
plt.legend(loc='best')


Out[1093]:
<matplotlib.legend.Legend at 0x2bf3ebd10>

In [1094]:
distmat_plot(dev_mat.dot(norm(1/normed_weights[-1])))
plt.title('Weighted (triplet, inverse)')
plt.savefig('ala_weighted_triplet_inv.pdf')

plt.figure()
distmat_plot(dev_mat.dot(norm(normed_weights[-1])))
plt.title('Weighted (triplet)')
plt.savefig('ala_weighted_triplet.pdf')



In [1095]:
from sklearn.cluster import SpectralClustering

distmat = dev_mat.dot(norm(1/normed_weights[-1]))
distmat = distmat + distmat.T

sc = SpectralClustering(affinity='precomputed')
sc.fit(distmat)


Out[1095]:
SpectralClustering(affinity='precomputed', assign_labels='kmeans', coef0=1,
          degree=3, eigen_solver=None, eigen_tol=0.0, gamma=1.0,
          kernel_params=None, n_clusters=8, n_init=10, n_neighbors=10,
          random_state=None)

In [1096]:
eig_vals,eig_vecs = np.linalg.eigh(distmat)
plt.hist(eig_vals,bins=100);



In [1097]:
plt.plot(eig_vals[-10:])


Out[1097]:
[<matplotlib.lines.Line2D at 0x2ae02dd10>]

In [1109]:
dist_mat_uw = dev_mat.dot(np.ones(22))
dist_mat_uw = dist_mat_uw + dist_mat_uw.T
eig_vals_uw,_ = np.linalg.eigh(dist_mat_uw)
plt.plot(norm(eig_vals[-50:-1][::-1]),label='weighted')
plt.plot(norm(eig_vals_uw[-50:-1][::-1]),label='unweighted')
plt.legend(loc='best')
plt.ylabel('Normalized eigenvalues')
plt.xlabel('Eigenvalue index')


Out[1109]:
<matplotlib.text.Text at 0x2bafce9d0>

In [1110]:
triplet_batch_objective(normed_weights[-1],dev_t),triplet_batch_objective(norm(1/normed_weights[-1]),dev_t)


Out[1110]:
(-1.2174985944313632, -0.75468796465818089)

In [764]:
dev_t.shape


Out[764]:
(9908, 10, 22)

In [538]:
dev_t.shape[0]/10


Out[538]:
990

In [542]:
raw_points = sgd(batch_objective,dev_t,norm(weights),batch_size=10,n_iter=1000,step_size=0.001,seed=0)
normed_points = np.array([norm(s) for s in raw_points])
plt.plot(np.abs(normed_points));
plt.figure()
plt.plot(raw_points);



In [545]:
raw_points = sgd(batch_objective,dev_t_fs,norm(weights_fs),batch_size=10,n_iter=1000,step_size=0.0001,seed=0)
normed_points = np.array([norm(s) for s in raw_points])
plt.plot(np.abs(normed_points));
plt.figure()
plt.plot(raw_points);



In [519]:
raw_points[:100,1]


Out[519]:
array([ 0.31414733,  0.30332838,  0.29349807,  0.28345948,  0.27782058,
        0.2749867 ,  0.2733865 ,  0.26737435,  0.26628604,  0.26243024,
        0.25947423,  0.25947423,  0.25513815,  0.25199986,  0.24987689,
        0.24987689,  0.24643087,  0.24643087,  0.24643087,  0.24643087,
        0.24643087,  0.24293042,  0.2380981 ,  0.23747412,  0.23200193,
        0.2326575 ,  0.22366119,  0.22366119,  0.22876661,  0.22154619,
        0.2257132 ,  0.2138445 ,  0.222414  ,  0.21196909,  0.21196909,
        0.21196909,  0.22004061,  0.20815076,  0.20815076,  0.21863401,
        0.2087079 ,  0.21681917,  0.21681917,  0.20514813,  0.21344224,
        0.20468115,  0.21142551,  0.19958512,  0.21087144,  0.21087144,
        0.19926102,  0.20619425,  0.20041358,  0.20041358,  0.20326625,
        0.19177422,  0.20028841,  0.18962811,  0.18962811,  0.20030296,
        0.19043069,  0.20036511,  0.19017643,  0.19644176,  0.18686938,
        0.19611237,  0.19611237,  0.19611237,  0.19611237,  0.18544529,
        0.19657272,  0.18489183,  0.19378337,  0.19378337,  0.19378337,
        0.18327227,  0.19173896,  0.19173896,  0.18078421,  0.19013494,
        0.19013494,  0.18028395,  0.18918584,  0.17923712,  0.17923712,
        0.1877851 ,  0.17720661,  0.18692798,  0.17482374,  0.18477889,
        0.17474097,  0.18545448,  0.17488109,  0.18408053,  0.17296767,
        0.18224677,  0.16959202,  0.18114   ,  0.16894235,  0.18080792])

In [554]:
results = [batch_objective(b,dev_t_fs) for b in normed_points[::(len(normed_points)/100)]]

In [555]:
plt.plot(results)
plt.hlines(batch_objective(norm(weights_fs),dev_t_fs),0,len(results))


Out[555]:
<matplotlib.collections.LineCollection at 0x2a2b42890>

In [560]:
#for i in range(100):
#    plt.plot(np.abs(normed_points[i*10]))
plt.plot(norm(weights_fs),label='deterministic')
plt.plot(np.abs(normed_points[100]),label='SGD')
plt.legend(loc='best')


Out[560]:
<matplotlib.legend.Legend at 0x28cb46150>

In [339]:
normed_points[1]


Out[339]:
array([ 0.09262412,  0.083905  ,  0.15303478,  0.18383372,  0.18944637,
        0.15923845,  0.08234395,  0.12010008,  0.06150959,  0.05872907,
        0.11158566,  0.12441057,  0.1442361 ,  0.14246173,  0.04589156,
        0.06671119,  0.04316535,  0.06148455,  0.05025831,  0.08896815,
        0.09112393,  0.11873766,  0.12467369,  0.11665577,  0.0116853 ,
        0.04824015, -0.02087788, -0.01734914,  0.01202569,  0.03310767,
        0.07431809,  0.1280293 ,  0.09774977,  0.10695625,  0.00746239,
        0.04993101, -0.01391455, -0.01316326,  0.00368581,  0.02962098,
        0.03949316,  0.07508904,  0.06545504,  0.06961499, -0.01246463,
        0.02332599, -0.03021613, -0.0061458 , -0.02718408, -0.0082692 ,
        0.0010858 ,  0.03862725,  0.047269  ,  0.04995896, -0.04581255,
       -0.05739671, -0.02483093, -0.00207665, -0.02111347, -0.01448523,
        0.00784799,  0.03094038,  0.04133141,  0.05483837, -0.03265709,
       -0.02083627, -0.02920801, -0.01464491, -0.01758712, -0.00268316,
        0.01596624,  0.06459417,  0.04813297,  0.04221584, -0.03504525,
       -0.01915489, -0.05146779, -0.04410759, -0.0496097 , -0.02980237,
       -0.03119635,  0.01003764,  0.02306434,  0.03034059, -0.06085426,
       -0.03361767, -0.07569833, -0.07116752, -0.06012815, -0.0401182 ,
       -0.03990953, -0.01438255, -0.01356045, -0.00430028,  0.03632213,
        0.04128349,  0.01647587,  0.05277396,  0.04915587,  0.03644413,
        0.06330783,  0.08284904,  0.12968946,  0.14244026,  0.16956029,
        0.11440506,  0.10942604,  0.16112755, -0.05562018, -0.02300453,
       -0.06802936, -0.06325471, -0.05296476, -0.03367583, -0.02810572,
        0.01895139,  0.01501081,  0.02217634, -0.05517253, -0.02376363,
       -0.06339174, -0.05127067, -0.05267087, -0.03837038, -0.02862667,
        0.01936633,  0.02000767,  0.02638699, -0.05416424, -0.02290243,
       -0.05661396, -0.03911232, -0.04237775, -0.01762503, -0.01496253,
        0.02637243,  0.03898788,  0.03327484, -0.05528708, -0.0513843 ,
       -0.03891265, -0.02080185, -0.02555311, -0.01156179,  0.00371672,
        0.04783349,  0.05016404,  0.02784407, -0.03812756, -0.01820265,
       -0.05189163, -0.04526953, -0.04624407, -0.02947586, -0.02238157,
        0.00946418,  0.00735451,  0.00704594,  0.04721694,  0.04482327,
        0.03020711,  0.06406489,  0.07346526,  0.04267531,  0.07940178,
        0.05753726,  0.09823697,  0.12477035,  0.11517717,  0.07675545,
        0.09424711,  0.10554558, -0.05269389, -0.0257805 , -0.0689909 ,
       -0.0651961 , -0.05841118, -0.03614062, -0.03497846,  0.01784938,
        0.02511484,  0.01669949, -0.06020436, -0.03745976, -0.05231702,
       -0.03812394, -0.0347215 , -0.01269001, -0.00239612,  0.03242807,
        0.05318981,  0.04655324, -0.0482486 , -0.03255557, -0.0509375 ,
       -0.04126221, -0.04768372, -0.03565864, -0.02062403,  0.01274364,
        0.02701797,  0.02633158, -0.0504922 , -0.03703674, -0.0434392 ,
       -0.03596047, -0.02313478, -0.01151866,  0.01134959,  0.04126616,
        0.0523898 ,  0.05912514, -0.0294864 , -0.00530842, -0.04653215,
       -0.04668021, -0.034659  , -0.0168606 , -0.01807919,  0.00264143,
       -0.00716474,  0.0008519 ,  0.01205143,  0.01825566,  0.02717176,
        0.05491529,  0.03719866,  0.0422826 ,  0.06465702,  0.06910888,
        0.09047794,  0.09180014,  0.1227335 ,  0.10697872,  0.10665104,
        0.1487687 , -0.02659878, -0.00135817, -0.0246136 , -0.02480806,
       -0.00240736,  0.01601518,  0.01542427,  0.05103953,  0.03178504,
        0.06625403,  0.00269788,  0.03413005, -0.00340027, -0.00523402,
        0.01529173,  0.03203057,  0.02966765,  0.06943664,  0.05653103,
        0.0535181 ,  0.04055681,  0.06140364,  0.06246918,  0.06368046,
        0.09925732,  0.11563953,  0.11926776,  0.13192084])

In [241]:
a = np.ones(10)
a[5:15]


Out[241]:
array([ 1.,  1.,  1.,  1.,  1.])

In [ ]:
def

In [ ]:
objective

In [ ]:
plt.plot(diffs.mean(0)[1:],label='Mean')
plt.plot(np.median(diffs[:,1:],0),label='Median')
#plt.plot(np.min(diffs[:,1:],0),label='Min')
#plt.plot(np.max(diffs[:,1:],0),label='Max')
plt.xlabel(r'Time lag ($\tau$)')
plt.ylabel('RMSD')
plt.legend(loc='best')
plt.title('Fs Peptide')

In [178]:
dev_mat.dot(weights).shape


Out[178]:
(1000, 1000)

In [297]:
plt.imshow(dev_mat_fs.dot(weights_fs))


Out[297]:
<matplotlib.image.AxesImage at 0x29f0f0f90>

In [298]:
distmat_plot(dev_mat_fs.dot(weights_fs))



In [546]:
distmat_plot(dev_mat_fs.dot(np.abs(normed_points[-1])))



In [547]:
distmat_plot(dev_mat_fs.dot(np.ones(len(weights_fs))))



In [561]:
distmat_plot(dev_mat_fs.dot(norm(np.ones(len(weights_fs)))) / (1+dev_mat_fs.dot(np.abs(normed_points[100]))))



In [ ]:
# other idea: compute a gradient based on kinetically near/far triplets

In [568]:
batch = dev_t[:100]

In [564]:
def penalty(metric,center,close,far,inf=10000):
    ''' compute d_metric(center,close) and d_metric(center,far), and penalize 
    (d_metric(center,close)/d_metric(center,far)) -- i.e. we want the distance from the center point
    to the "close" point to be small relative to the distance from the center point to the "far" point'''
    d_far = metric(center,far)
    d_close = metric(center,close)
    return d_close/d_far

In [570]:
tau_1=1
tau_2=10
triplets = np.array([(batch[i],batch[i+tau_1],batch[i+tau_2]) for i in range(len(batch)-tau_2)])
triplets.shape


Out[570]:
(90, 3, 10, 22)

In [572]:
triplets[0].shape


Out[572]:
(3, 10, 22)

In [575]:
from msmbuilder.featurizer import DihedralFeaturizer
dihedrals = DihedralFeaturizer().transform([t])[0]
dihedrals.shape


Out[575]:
(9999, 4)

In [612]:
dihedrals_fs = DihedralFeaturizer().transform([fs_t])[0]

In [577]:
from scipy.spatial.distance import euclidean

In [595]:
def near_far_triplet_loss(metric,batch,tau_1=1,tau_2=10):
    ''' batch is a numpy array of time-ordered observations'''
    #triplets = np.array([(batch[i],batch[i+tau_1],batch[i+tau_2]) for i in range(len(batch)-tau_2)])
    cost=0
    n_triplets = len(batch)-tau_2
    for i in range(n_triplets):
        cost+=penalty(metric,batch[i],batch[i+tau_1],batch[i+tau_2])
    return cost / n_triplets

In [596]:
near_far_triplet_cost(euclidean,dihedrals)


Out[596]:
0.8897582722030105

In [597]:
def triplet_wrmsd_loss(weights,dataset,tau_1=1,tau_2=10):
    ''' given a weight vector and a dataset, compute the
    '''
    metric = lambda x,y:wRMSD(x,y,weights)
    return near_far_triplet_loss(metric,dataset,tau_1,tau_2)

In [598]:
def generate_triplet_weighted_metric_loss(weighted_metric,tau_1=1,tau_2=10):
    return lambda weights,dataset:near_far_triplet_loss(lambda x,y:weighted_metric(x,y,weights),dataset,tau_1,tau_2)

In [601]:
def weighted_dot_prod(x,y,weights):
    return np.dot(x*weights,y)

In [640]:
from scipy.spatial.distance import mahalanobis
mahalanobis(x,y,np.diag(np.ones(len(x))))


Out[640]:
0.42989719828259687

In [642]:
euclidean(x,y)


Out[642]:
0.4298971891403198

In [603]:
generate_triplet_weighted_metric_loss(weighted_dot_prod)(np.ones(4),dihedrals)


Out[603]:
0.38391835095146271

In [643]:
generate_triplet_weighted_metric_loss(mahalanobis)(np.diag(np.ones(len(x))),dihedrals)


Out[643]:
0.88975827169429644

In [ ]:
grad(

In [586]:
x = dihedrals[0]
y = dihedrals[1]
x.T.dot(y)


Out[586]:
1.9075942

In [590]:
(x*np.ones(len(x))).dot(y)


Out[590]:
1.9075942020746846

In [634]:
raw_points = sgd(generate_triplet_weighted_metric_loss(weighted_dot_prod),dihedrals,np.ones(4),batch_size=100,step_size=0.001)

In [637]:
plt.plot(raw_points);
normed_points = np.array([norm(s) for s in raw_points])
plt.figure()
plt.plot(normed_points);



In [ ]:


In [631]:
raw_points = sgd(generate_triplet_weighted_metric_loss(weighted_dot_prod),dihedrals_fs,np.ones(dihedrals_fs.shape[1]),n_iter=1000,batch_size=30,step_size=0.1)

In [632]:
normed_points = np.array([norm(s) for s in raw_points])

In [633]:
plt.plot(raw_points);
plt.figure()
plt.plot(normed_points);



In [747]:
def np_mahalanobis(x,y,A_vec):
    A = np.diag(A_vec)
    return np.sqrt(np.dot(np.dot(x,A),np.dot(y,A)))

In [749]:
raw_points = sgd(objective=generate_triplet_weighted_metric_loss(np_mahalanobis,tau_1=1,tau_2=5),
                 dataset=dihedrals,init_point=np.random.rand(4),batch_size=20,step_size=0.1,n_iter=1000)

In [653]:
def projected_distance(x,y,A):
    return np.sqrt(np.dot(np.dot(x,A),np.dot(y,A)))

In [667]:
def projected_distance_vec(x,y,A_vec):
    A = np.reshape(A_vec,(len(x),len(A_vec)/len(x)))
    return projected_distance(x,y,A)

In [734]:
raw_points = sgd(objective=generate_triplet_weighted_metric_loss(projected_distance_vec,tau_1=1,tau_2=5),
                 dataset=dihedrals,init_point=np.random.rand(8),batch_size=20,step_size=0.1,n_iter=1000)

In [735]:
np.dot(dihedrals,np.reshape(raw_points[-1],(4,2))).shape


Out[735]:
(9999, 2)

In [750]:
plt.plot(raw_points);



In [752]:
normed_points = np.array([norm(s) for s in raw_points])
plt.plot(normed_points);



In [760]:
reload(weighted_rmsd)
from weighted_rmsd import wRMSD_xyz

In [ ]:


In [761]:
raw_points = sgd(objective=generate_triplet_weighted_metric_loss(wRMSD_xyz,tau_1=1,tau_2=5),
                 dataset=t,init_point=np.random.rand(4),batch_size=20,step_size=0.1,n_iter=10)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-761-d9b334b8d703> in <module>()
      1 raw_points = sgd(objective=generate_triplet_weighted_metric_loss(wRMSD_xyz,tau_1=1,tau_2=5),
----> 2                  dataset=t,init_point=np.random.rand(4),batch_size=20,step_size=0.1,n_iter=10)

<ipython-input-683-4fa7f71edcff> in sgd(objective, dataset, init_point, batch_size, n_iter, step_size, seed)
      6     testpoints[0] = init_point
      7     #testpoints.append(init_point)
----> 8     shuffled = np.array(dataset)
      9     np.random.shuffle(shuffled)
     10     accept_frac = 1.0*batch_size/dataset.shape[0]

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/numpy/numpy_wrapper.pyc in array(A, *args, **kwargs)
     35     else:
     36         raw_array = np.array(A, *args, **kwargs)
---> 37         return wrap_if_nodes_inside(raw_array)
     38 
     39 def wrap_if_nodes_inside(raw_array, slow_op_name=None):

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/numpy/numpy_wrapper.pyc in wrap_if_nodes_inside(raw_array, slow_op_name)
     42             warnings.warn("{0} is slow for array inputs. "
     43                           "np.concatenate() is faster.".format(slow_op_name))
---> 44         return array_from_args(raw_array.shape, *raw_array.ravel())
     45     else:
     46         return raw_array

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in __call__(self, *args, **kwargs)
    105                         tapes.add(tape)
    106 
--> 107         result = self.fun(*argvals, **kwargs)
    108         if result is NotImplemented: return result
    109         if ops:

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/numpy/numpy_wrapper.pyc in array_from_args(front_shape, *args)
     49 def array_from_args(front_shape, *args):
     50     new_array = np.array(args)
---> 51     return new_array.reshape(front_shape + new_array.shape[1:])
     52 
     53 def array_from_args_gradmaker(argnum, ans, front_shape, *args):

ValueError: sequence too large; must be smaller than 32

In [ ]:
wRMSD(

In [753]:
grad(lambda weights:generate_triplet_weighted_metric_loss(projected_distance_vec,tau_1=1,tau_2=5)(weights,dihedrals))(np.ones(8))


Out[753]:
array([ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan])

In [745]:
projected_distance_vec(dihedrals[10],dihedrals[10],np.ones(8))


Out[745]:
1.8950667884464911

In [ ]: