In [1]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
%matplotlib inline
from autograd import grad
import autograd.numpy as np

from sklearn.datasets import load_digits
data = load_digits()
X,Y = load_digits().data,load_digits().target
X.shape


Out[1]:
(1797, 64)

In [2]:
# in a triplet network, the loss is defined on triplets of observations

In [3]:
def triplet_loss_paper(distance_close,distance_far):
    ''' loss function given two distances from the triplet (a,b,c)
    where distance_close=d(a,b), the distance between two points that
    should be close, and distance_far=d(a,c), the 
    distance between two points that should be far.
    
    see page 3 of: http://arxiv.org/abs/1412.6622'''
    exp_close = np.exp(distance_close)
    exp_far = np.exp(distance_close)
    
    
    d_plus = exp_close / (exp_close + exp_far)
    d_minus = exp_far / (exp_close + exp_far)
    
    return (d_plus-d_minus)**2

def modified_triplet_loss(distance_close,distance_far):
    return distance_far - distance_close

def zero_one_triplet_loss(distance_close,distance_far):
    return 1.0*(distance_far < distance_close)

def distance(x,y):
    return np.sqrt(np.sum((x-y)**2))
                 
def triplet_objective(transformation,x,x_close,x_far,
                      triplet_loss=modified_triplet_loss):
    '''for a metric embedding, we can transform each point separately'''
    t_a,t_b,t_c = [transformation(point) for point in (x,x_close,x_far)]
    
    return triplet_loss(distance(t_a,t_b),distance(t_a,t_c))

def generate_triplet(X,Y):
    x_ind = np.random.randint(len(X))
    x,y = X[x_ind],Y[x_ind]
    x_close_ind = np.random.randint(len(X[Y==y]))
    x_far_ind = np.random.randint(len(X[Y!=y]))
    x_close = X[Y==y][x_close_ind]
    x_far = X[Y!=y][x_far_ind]
    return x,x_close,x_far

In [14]:
%timeit generate_triplet(X,Y)


The slowest run took 6.65 times longer than the fastest. This could mean that an intermediate result is being cached 
10000 loops, best of 3: 194 µs per loop

In [4]:
from time import time
t = time()
triplets = [generate_triplet(X[:1000],Y[:1000]) for i in range(10000)]
triplets_test = [generate_triplet(X[1000:],Y[1000:]) for i in range(10000)]
print(time() - t)


19.8850970268

In [5]:
triplets = np.array(triplets)
triplets.shape


Out[5]:
(10000, 3, 64)

In [6]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca.fit(X)
transform = lambda x: pca.transform(x)[0]

In [7]:
triplet_objective(pca.transform,*triplets[10])


Out[7]:
-12.13665727922605

In [8]:
t =time()
obj = [triplet_objective(transform,*triplet) for triplet in triplets]
print(time() - t)


0.717976093292

In [9]:
np.min(obj),np.max(obj)


Out[9]:
(-32.548446540413366, 49.973049670121881)

In [142]:
def mahalanobis(x,y,diag):
    return np.sqrt(np.dot((x-y)*diag,(x-y)))
    #return np.sqrt(np.sum(np.abs(np.dot(np.outer((x-y),diag),(x-y)))))
    #return np.sqrt(np.dot(np.outer((x-y),np.diag(diag)),x-y))

In [10]:
def weighted_metric(x,y,a,W):
    ''' a is a non-negative vector of length len(x)=len(y), w is a len(x)-by-len(x) real matrix'''
    return np.sqrt(np.dot(np.dot(np.dot(np.dot((x-y).T,np.diag(a)),W),np.diag(a).T),(x-y)))

In [13]:
weighted_metric(np.ones(2),np.ones(2)*2,np.ones(2),np.ones((2,2)))


Out[13]:
2.0

In [14]:
%timeit weighted_metric(np.ones(2),np.ones(2)*2,np.ones(2),np.ones((2,2)))


10000 loops, best of 3: 37.5 µs per loop

In [ ]:


In [143]:
mahalanobis(np.ones(2),np.ones(2),np.ones(2))


Out[143]:
0.0

In [ ]:
pca.transform(

In [146]:
mahalanobis(pca.transform(X[0])[0],pca.transform(X[1])[0],np.ones(2))


Out[146]:
43.042041614801654

In [47]:
pca.transform(X[0])


Out[47]:
array([[  1.25946645, -21.27488348]])

In [147]:
def mahalanobis_obj(weights,triplet,triplet_loss=modified_triplet_loss):
    return triplet_loss(mahalanobis(triplet[0],triplet[1],weights),
                        mahalanobis(triplet[0],triplet[2],weights))

def batch_mahalanobis_obj(weights,triplets,triplet_loss=modified_triplet_loss):
    loss = 0
    for i in range(len(triplets)):
        loss += mahalanobis_obj(weights,triplets[i],triplet_loss)
    return loss / len(triplets)


print(mahalanobis_obj(np.ones(len(triplets[0][0])),triplets[0]))
print(batch_mahalanobis_obj(np.ones(len(triplets[0][0])),triplets[:100]))


23.0999475116
14.2518042666

In [148]:
grad(lambda w:mahalanobis_obj(w,triplets[0]))(np.ones(len(triplets[0][0])))


Out[148]:
array([ 0.        ,  0.        ,  0.25671086, -0.22016996, -0.08341364,
        0.01426944, -0.04613841, -0.01953662,  0.        ,  0.        ,
        0.23717424, -1.86124575, -1.57219751, -0.01953662,  0.02153712,
        0.        ,  0.        ,  0.01026843, -0.39599951,  0.16429495,
        0.07287929, -0.05587501,  0.01026843,  0.        ,  0.        ,
        0.        ,  1.47865453,  1.40050807,  0.08088131, -0.45414094,
       -0.06787803,  0.        ,  0.        , -0.07814647,  0.83174318,
        2.55057271,  2.62871917, -0.27151213, -0.75072277,  0.        ,
        0.        , -0.31258587, -0.27151213,  2.31039771, -0.22016996,
        1.70002725, -0.01953662,  0.        ,  0.        ,  0.16429495,
        0.65591363, -0.74719019,  0.81220656,  1.24248055,  0.        ,
        0.        ,  0.        ,  0.        ,  0.29151717, -0.91622048,
        2.62871917,  0.36966363,  0.        ,  0.        ])

In [547]:
def sgd(objective,dataset,init_point,n_iter=100,step_size=0.01,seed=0,
        stoch_select=False,norm=False,store_intermediates=True):
    ''' objective takes in a parameter vector and an array of data'''
    np.random.seed(seed)
    x = init_point
    if store_intermediates:
        testpoints = np.zeros((n_iter,len(init_point)))
        testpoints[0] = x

    for i in range(1,n_iter):
        ind = np.random.randint(len(dataset))
        obj_grad = grad(lambda p:objective(p,dataset[ind]))
        raw_grad = obj_grad(x)
        #raw_grad = obj_grad(testpoints[i-1])
        gradient = np.nan_to_num(raw_grad)
        x = x-gradient*step_size
        if norm:
            x = np.abs(x) / np.sum(np.abs(x))
        #print(gradient,raw_grad)
        if store_intermediates:
            testpoints[i] = x
    return np.array(testpoints)

In [173]:
results = sgd(mahalanobis_obj,triplets,np.ones(len(triplets[0][0])),n_iter=10000,step_size=0.001)
plt.plot(results[:,:10]);



In [174]:
plt.plot(results);



In [175]:
len(triplets[0][0])


Out[175]:
64

In [176]:
batch_mahalanobis_obj(np.ones(64),triplets,zero_one_triplet_loss)


Out[176]:
0.1188

In [177]:
batch_mahalanobis_obj(results[-1],triplets,zero_one_triplet_loss)


Out[177]:
0.0016000000000000001

In [178]:
batch_mahalanobis_obj(np.ones(64),triplets_test,zero_one_triplet_loss),batch_mahalanobis_obj(results[-1],triplets_test,zero_one_triplet_loss)


Out[178]:
(0.1149, 0.0023999999999999998)

In [179]:
grad(lambda w:batch_mahalanobis_obj(w,triplets[:100]))(np.ones(len(triplets[0][0])))


Out[179]:
array([  0.00000000e+00,  -1.39697931e-03,   1.93212827e-01,
         9.36037000e-02,  -2.18671192e-02,   2.76214612e-01,
         2.37870490e-02,   2.35728505e-03,   8.39181358e-05,
         7.15165796e-02,   2.02456880e-01,   6.79819088e-02,
         5.27159976e-02,   1.75685569e-01,  -1.43022042e-02,
         8.70477997e-03,   0.00000000e+00,   9.58738064e-02,
         4.04692027e-01,   2.39708052e-01,   2.32050097e-01,
         4.58627400e-01,   9.05272474e-03,   6.37398746e-04,
         0.00000000e+00,   8.48901422e-02,   3.77995955e-01,
         2.99422449e-01,   2.31194348e-01,   1.61737823e-01,
         5.18772763e-02,   0.00000000e+00,   0.00000000e+00,
         8.14545187e-02,   3.77108939e-01,   2.45308749e-01,
         2.12929678e-01,   7.15265485e-02,   7.73420477e-02,
         0.00000000e+00,   1.36757132e-03,   2.59236001e-02,
         6.07959753e-01,   4.05438589e-01,   3.48175160e-01,
        -4.21190563e-03,   1.39783702e-01,  -4.66713294e-04,
         8.54732075e-05,  -4.46887853e-02,   2.65310950e-01,
        -3.09890274e-02,   5.75891139e-02,   1.35251184e-01,
         5.98376330e-02,  -1.62749887e-02,   0.00000000e+00,
         2.93990289e-03,   2.09338149e-01,   1.01533911e-01,
         3.11074969e-02,   1.70908476e-01,   1.49561133e-02,
         3.78940054e-03])

In [121]:
cheap_triplet_batch = lambda batch_size=50: triplets[np.random.randint(0,len(triplets),batch_size)]

In [66]:
%timeit cheap_triplet_batch()


The slowest run took 11.05 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 13.5 µs per loop

In [134]:
example_triplet_batch = cheap_triplet_batch()
#stoch_mahalanobis_grad = grad(lambda weights:batch_mahalanobis_obj(weights,cheap_triplet_batch()))
mahalanobis_grad = grad(lambda weights:batch_mahalanobis_obj(weights,example_triplet_batch))

In [180]:
mahalanobis_grad(np.ones(64))


Out[180]:
array([  0.00000000e+00,  -6.96726310e-03,   3.17271575e-01,
         2.54622029e-01,  -9.71341500e-02,   1.45346684e-01,
         1.70738525e-01,  -1.85897354e-02,  -2.12905602e-03,
        -1.11509717e-02,   4.55695827e-01,   6.62022204e-03,
         6.71817729e-02,   2.94194082e-01,   6.32200166e-02,
         5.92929531e-03,   0.00000000e+00,  -1.84055807e-02,
         8.13803529e-02,   6.29166708e-02,   2.24273293e-01,
         3.05723228e-01,  -3.46408863e-02,   2.25839244e-03,
         0.00000000e+00,   4.68715055e-02,   3.81600678e-01,
         2.66462450e-01,   2.50936457e-01,   6.26198973e-02,
         8.62191936e-02,  -1.00776028e-04,   0.00000000e+00,
         2.16081843e-01,   2.95408092e-01,   1.93225736e-01,
         2.51780631e-01,   1.38675144e-01,   5.51463743e-02,
         0.00000000e+00,   1.58750159e-04,   1.32676238e-01,
         6.55123063e-01,   4.16378105e-01,  -1.04148488e-02,
        -7.09007307e-02,   3.57721105e-02,   5.78278405e-04,
         0.00000000e+00,  -4.94821999e-03,   2.07396818e-01,
         2.18731103e-01,  -1.60465160e-01,   1.11217819e-01,
         1.38612110e-01,   5.69644481e-03,   0.00000000e+00,
        -7.04788268e-03,   3.82920488e-01,   4.05046723e-01,
         1.27811141e-01,   1.19607147e-01,   2.56915223e-01,
         5.30695802e-03])

In [181]:
from scipy.optimize import minimize

result = minimize(lambda w:batch_mahalanobis_obj(w,cheap_triplet_batch()),np.ones(64),jac=mahalanobis_grad)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-181-934a8b9d3c7a> in <module>()
      1 from scipy.optimize import minimize
      2 
----> 3 result = minimize(lambda w:batch_mahalanobis_obj(w,cheap_triplet_batch()),np.ones(64),jac=mahalanobis_grad)

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/scipy/optimize/_minimize.pyc in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    417         return _minimize_cg(fun, x0, args, jac, callback, **options)
    418     elif meth == 'bfgs':
--> 419         return _minimize_bfgs(fun, x0, args, jac, callback, **options)
    420     elif meth == 'newton-cg':
    421         return _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/scipy/optimize/optimize.pyc in _minimize_bfgs(fun, x0, args, jac, callback, gtol, norm, eps, maxiter, disp, return_all, **unknown_options)
    840     I = numpy.eye(N, dtype=int)
    841     Hk = I
--> 842     old_fval = f(x0)
    843     old_old_fval = None
    844     xk = x0

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/scipy/optimize/optimize.pyc in function_wrapper(*wrapper_args)
    280     def function_wrapper(*wrapper_args):
    281         ncalls[0] += 1
--> 282         return function(*(wrapper_args + args))
    283 
    284     return ncalls, function_wrapper

<ipython-input-181-934a8b9d3c7a> in <lambda>(w)
      1 from scipy.optimize import minimize
      2 
----> 3 result = minimize(lambda w:batch_mahalanobis_obj(w,cheap_triplet_batch()),np.ones(64),jac=mahalanobis_grad)

<ipython-input-121-2b466cb5a3cc> in <lambda>(batch_size)
----> 1 cheap_triplet_batch = lambda batch_size=50: triplets[np.random.randint(0,len(triplets),batch_size)]

TypeError: only integer arrays with one element can be converted to an index

In [83]:
result


Out[83]:
   status: 2
  success: False
     njev: 42
     nfev: 42
 hess_inv: array([[1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       ..., 
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 0, 1]])
      fun: nan
        x: array([  1.00000000e+00,  -2.04972142e+00,  -3.28804201e+02,
        -8.20758617e+01,   1.82302864e+02,  -5.19554039e+02,
        -1.68789951e+02,   9.33535365e+00,   1.23800285e+00,
        -1.06486690e+02,   4.32413100e+01,  -3.01148889e+01,
        -1.16471428e+01,  -2.23832204e+02,  -2.62675480e+01,
        -1.39457164e+00,   1.00000000e+00,  -7.95163061e+01,
        -1.84097946e+02,  -6.53462261e+01,  -3.22324761e+02,
        -3.30967046e+02,   4.62477776e+01,   2.86011480e-01,
         1.00000000e+00,   3.67650376e+01,  -2.90868142e+02,
        -1.75700188e+02,  -2.72414512e+02,  -5.39388205e+02,
        -1.23128143e+02,   1.00000000e+00,   1.00000000e+00,
        -7.32037147e+01,  -4.40033279e+02,  -9.01290311e+01,
        -3.98978109e+02,  -4.78361164e+02,  -5.16434362e+01,
         1.00000000e+00,   1.27849129e+00,   1.97984501e+01,
        -4.21577674e+02,  -2.22017404e+02,  -2.58640558e+02,
         9.07126461e+01,  -1.16832710e+02,   1.00000000e+00,
         1.00000000e+00,   2.65124224e+01,   2.73738467e+01,
        -7.73897198e+01,  -1.16755775e+02,  -1.46423183e+02,
        -1.61476287e+02,  -8.00288203e+00,   1.00000000e+00,
         4.94416707e+00,  -2.71559786e+02,   1.13369825e+01,
        -4.84722064e+02,  -4.30802580e+02,  -4.84629807e+01,
        -5.46283541e+00])
  message: 'Desired error not necessarily achieved due to precision loss.'
      jac: array([ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
        nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
        nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
        nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
        nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
        nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan])

In [85]:
plt.plot(result.x)


Out[85]:
[<matplotlib.lines.Line2D at 0x16f73e310>]

In [118]:
mahalanobis(triplets[0][0],triplets[0][1],result.x)


Out[118]:
1092.8740552667703

In [109]:
x,y=triplets[0][1],triplets[0][0]
np.sum(np.abs(np.dot(np.outer((x-y),result.x),(x-y))))


Out[109]:
1194373.7006752358

In [106]:
np.outer(x-y,x-y).shape


Out[106]:
(64, 64)

In [117]:
batch_mahalanobis_obj(result.x,triplets[:10])


Out[117]:
302.77880056006222

In [2]:
def feedforward_network

In [5]:
import keras

In [182]:
def objective(M):
    return np.sum(np.abs(M))

grad(objective)(np.ones((10,10)))


Out[182]:
array([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])

In [183]:
def tripletify_trajectory(X,tau_1=5,tau_2=20):
    X_triplets = []
    for i in range(len(X) - tau_2):
        X_triplets.append((X[i],X[i+tau_1],X[i+tau_2]))
    return X_triplets

In [186]:
X_ = tripletify_trajectory(X)


Out[186]:
(array([  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.,   0.,   0.,  13.,
         15.,  10.,  15.,   5.,   0.,   0.,   3.,  15.,   2.,   0.,  11.,
          8.,   0.,   0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.,   0.,
          5.,   8.,   0.,   0.,   9.,   8.,   0.,   0.,   4.,  11.,   0.,
          1.,  12.,   7.,   0.,   0.,   2.,  14.,   5.,  10.,  12.,   0.,
          0.,   0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]),
 array([  0.,   0.,  12.,  10.,   0.,   0.,   0.,   0.,   0.,   0.,  14.,
         16.,  16.,  14.,   0.,   0.,   0.,   0.,  13.,  16.,  15.,  10.,
          1.,   0.,   0.,   0.,  11.,  16.,  16.,   7.,   0.,   0.,   0.,
          0.,   0.,   4.,   7.,  16.,   7.,   0.,   0.,   0.,   0.,   0.,
          4.,  16.,   9.,   0.,   0.,   0.,   5.,   4.,  12.,  16.,   4.,
          0.,   0.,   0.,   9.,  16.,  16.,  10.,   0.,   0.]),
 array([  0.,   0.,   3.,  13.,  11.,   7.,   0.,   0.,   0.,   0.,  11.,
         16.,  16.,  16.,   2.,   0.,   0.,   4.,  16.,   9.,   1.,  14.,
          2.,   0.,   0.,   4.,  16.,   0.,   0.,  16.,   2.,   0.,   0.,
          0.,  16.,   1.,   0.,  12.,   8.,   0.,   0.,   0.,  15.,   9.,
          0.,  13.,   6.,   0.,   0.,   0.,   9.,  14.,   9.,  14.,   1.,
          0.,   0.,   0.,   2.,  12.,  13.,   4.,   0.,   0.]))

In [198]:
def deviation_ify_triplets(triplets):
    deviations = np.zeros((len(triplets),2,len(triplets[0][0])))
    for i,(a,b,c) in enumerate(triplets):
        deviations[i][0] = a-b
        deviations[i][1] = a-c
    return deviations

In [199]:
deviations = deviation_ify_triplets(X_)

In [324]:
def weighted_objective(weights,deviation_pair,norm=False):
    close = np.dot(deviation_pair[0],np.abs(weights))
    far = np.dot(deviation_pair[1],np.abs(weights))
    if norm:
        return (close-far) / (close+far)
    else:
        return close-far

In [289]:
weighted_objective(np.ones(64),deviations[0])


Out[289]:
0.054945054945054944

In [290]:
stoch_objective = lambda weights:weighted_objective(weights,deviations[np.random.randint(len(deviations))])

In [291]:
stoch_objective(np.ones(64))


Out[291]:
3.0

In [334]:
results = sgd(weighted_objective,deviations,np.ones(64),n_iter=10000,step_size=0.001)

In [335]:
plt.plot(results);



In [336]:
from msmbuilder.example_datasets import AlanineDipeptide,FsPeptide
ala = AlanineDipeptide().get()
ala_traj = ala.trajectories[0]
fs = FsPeptide().get()
fs_traj = fs.trajectories[0]


loading trajectory_1.xtc...
loading trajectory_10.xtc...
loading trajectory_11.xtc...
loading trajectory_12.xtc...
loading trajectory_13.xtc...
loading trajectory_14.xtc...
loading trajectory_15.xtc...
loading trajectory_16.xtc...
loading trajectory_17.xtc...
loading trajectory_18.xtc...
loading trajectory_19.xtc...
loading trajectory_2.xtc...
loading trajectory_20.xtc...
loading trajectory_21.xtc...
loading trajectory_22.xtc...
loading trajectory_23.xtc...
loading trajectory_24.xtc...
loading trajectory_25.xtc...
loading trajectory_26.xtc...
loading trajectory_27.xtc...
loading trajectory_28.xtc...
loading trajectory_3.xtc...
loading trajectory_4.xtc...
loading trajectory_5.xtc...
loading trajectory_6.xtc...
loading trajectory_7.xtc...
loading trajectory_8.xtc...
loading trajectory_9.xtc...

In [295]:
ala_traj


Out[295]:
<mdtraj.Trajectory with 9999 frames, 22 atoms, 3 residues, without unitcells at 0x3771693d0>

In [296]:
from msmbuilder.featurizer import DihedralFeaturizer
dhf = DihedralFeaturizer()
dhft = dhf.transform([ala_traj])[0]

In [297]:
dhft.shape


Out[297]:
(9999, 4)

In [364]:
deviations_ala = deviation_ify_triplets(tripletify_trajectory(dhft,tau_1=5,tau_2=10))

In [365]:
results = sgd(weighted_objective,deviations_ala,np.ones(4),n_iter=10000,step_size=0.01)

In [366]:
plt.plot(results);



In [273]:
import sys
sys.path.append('../projects/metric-learning')
import weighted_rmsd

In [301]:
from weighted_rmsd import compute_atomwise_deviation,compute_atomwise_deviation_xyz

In [367]:
ala_triplets = tripletify_trajectory(ala_traj,tau_1=5,tau_2=10)
fs_triplets = tripletify_trajectory(fs_traj,tau_1=5,tau_2=10)

In [303]:
a = ala_triplets[0][0]
a.n_atoms


Out[303]:
22

In [304]:
def deviation_ify_protein_triplets(triplets):
    deviations = np.zeros((len(triplets),2,triplets[0][0].n_atoms))
    for i,(a,b,c) in enumerate(triplets):
        deviations[i][0] = compute_atomwise_deviation(a,b)
        deviations[i][1] = compute_atomwise_deviation(a,c)
    return deviations

In [514]:
ala_deviations = deviation_ify_protein_triplets(ala_triplets)
fs_deviations = deviation_ify_protein_triplets(fs_triplets)

In [515]:
fs_atoms = fs_deviations.shape[-1]

In [516]:
results_ala = sgd(weighted_objective,ala_deviations,np.ones(22),n_iter=10000,step_size=0.1)
results_fs = sgd(weighted_objective,fs_deviations,np.ones(fs_atoms),n_iter=10000,step_size=0.1)

In [517]:
norm_results_ala = 22*results_ala / np.outer(results_ala.sum(1),np.ones(results_ala.shape[1]))
norm_results_fs = fs_atoms*results_fs / np.outer(results_fs.sum(1),np.ones(fs_atoms))

In [518]:
plt.plot(norm_results_ala);
plt.figure()
plt.plot(norm_results_fs);



In [519]:
def zero_one_weighted_deviation_loss(weights,deviation_pair):
    return 1.0*(np.dot(deviation_pair[0],weights) > np.dot(deviation_pair[1],weights))

In [520]:
sum([zero_one_weighted_deviation_loss(np.ones(22),a) for a in ala_deviations]) / len(ala_deviations)


Out[520]:
0.4206627290019021

In [521]:
sum([zero_one_weighted_deviation_loss(np.abs(results_ala[-1]),a) for a in ala_deviations]) / len(ala_deviations)


Out[521]:
0.40955050555611172

In [522]:
sum([zero_one_weighted_deviation_loss(np.ones(fs_atoms),a) for a in fs_deviations]) / len(fs_deviations)


Out[522]:
0.31621621621621621

In [523]:
sum([zero_one_weighted_deviation_loss(np.abs(results_fs[-1]),a) for a in fs_deviations]) / len(fs_deviations)


Out[523]:
0.30760760760760758

In [363]:
from weighted_rmsd import compute_kinetic_weights

In [ ]:
compute_kinetic_weights(tra

In [319]:
print('residues with increased weight here: ',np.arange(len(norm_results))[norm_results[-1]>1])


('residues with increased weight here: ', array([11, 12, 13, 15]))

In [ ]:


In [372]:
from MDAnalysis.analysis import align

In [484]:
a = np.array(fs_traj[0].xyz[0],dtype=np.float64)
b = np.array(fs_traj[1].xyz[0],dtype=np.float64)
a


Out[484]:
array([[ 1.11600006, -1.08800006,  1.04900002],
       [ 1.13700008, -0.96800005,  1.028     ],
       [ 1.05200005, -1.12600005,  1.18000007],
       [ 1.01600003, -1.03800011,  1.23200011],
       [ 0.97300005, -1.20000005,  1.16300011],
       [ 1.12100005, -1.18000007,  1.24600005],
       [ 1.12800002, -1.18200004,  0.95800006],
       [ 1.13200009, -1.27800012,  0.98900002],
       [ 1.22500002, -1.15900004,  0.83800006],
       [ 1.31500006, -1.10600007,  0.86900002],
       [ 1.2750001 , -1.29900002,  0.79700005],
       [ 1.32800007, -1.34100008,  0.88200003],
       [ 1.34800005, -1.27900004,  0.71900004],
       [ 1.18800008, -1.36100006,  0.77300006],
       [ 1.15900004, -1.07600009,  0.72700006],
       [ 1.22800004, -0.99300003,  0.66300005],
       [ 1.02900004, -1.10900009,  0.70700002],
       [ 0.97500002, -1.1730001 ,  0.76300001],
       [ 0.95600003, -1.06400001,  0.58700001],
       [ 1.01800001, -1.0660001 ,  0.49800003],
       [ 0.84000003, -1.16500008,  0.57100004],
       [ 0.88500005, -1.26100004,  0.54400003],
       [ 0.77300006, -1.171     ,  0.65700001],
       [ 0.78000003, -1.13300002,  0.48500001],
       [ 0.91900003, -0.91500002,  0.59300005],
       [ 0.94300002, -0.84100002,  0.49900001],
       [ 0.87200004, -0.87300003,  0.71100003],
       [ 0.85600007, -0.94500005,  0.78000003],
       [ 0.81300002, -0.74200004,  0.74200004],
       [ 0.75400001, -0.74900001,  0.83300006],
       [ 0.92300004, -0.648     ,  0.76500005],
       [ 0.97900003, -0.625     ,  0.67400002],
       [ 0.88600004, -0.55400002,  0.80700004],
       [ 0.99000007, -0.68500006,  0.84200007],
       [ 0.70000005, -0.69600004,  0.64200002],
       [ 0.69000006, -0.574     ,  0.61300004],
       [ 0.63000005, -0.79000002,  0.57800001],
       [ 0.62900001, -0.88200003,  0.62100005],
       [ 0.52000004, -0.76700002,  0.48000002],
       [ 0.55400002, -0.68700004,  0.41300002],
       [ 0.50300002, -0.89100003,  0.39100003],
       [ 0.59800005, -0.91800004,  0.34600002],
       [ 0.47600001, -0.98100007,  0.44600001],
       [ 0.42800003, -0.87800002,  0.31300002],
       [ 0.39000002, -0.71900004,  0.54100001],
       [ 0.28500003, -0.78200006,  0.52600002],
       [ 0.38300002, -0.59400004,  0.58900005],
       [ 0.46900001, -0.54300004,  0.58000004],
       [ 0.259     , -0.52000004,  0.62300003],
       [ 0.192     , -0.58100003,  0.68400002],
       [ 0.30000001, -0.39700001,  0.708     ],
       [ 0.21400002, -0.33200002,  0.72200006],
       [ 0.384     , -0.34500003,  0.66100001],
       [ 0.33400002, -0.42400002,  0.80800003],
       [ 0.19800001, -0.45900002,  0.49500003],
       [ 0.27600002, -0.42100003,  0.40300003],
       [ 0.066     , -0.45000002,  0.48300001],
       [ 0.01      , -0.47100002,  0.56400001],
       [-0.008     , -0.45400003,  0.35800001],
       [ 0.04      , -0.53200001,  0.30000001],
       [-0.142     , -0.50300002,  0.39400002],
       [-0.132     , -0.59200001,  0.45600003],
       [-0.19600001, -0.53000003,  0.303     ],
       [-0.19500001, -0.42300001,  0.44600001],
       [-0.008     , -0.324     ,  0.28200001],
       [-0.017     , -0.32600001,  0.162     ],
       [ 0.016     , -0.208     ,  0.34600002],
       [ 0.032     , -0.21000001,  0.44600001],
       [ 0.03      , -0.068     ,  0.28800002],
       [-0.071     , -0.043     ,  0.25500003],
       [ 0.072     ,  0.023     ,  0.40100002],
       [ 0.16700001, -0.004     ,  0.44700003],
       [-0.006     ,  0.023     ,  0.47800002],
       [ 0.083     ,  0.126     ,  0.36700001],
       [ 0.11700001, -0.056     ,  0.171     ],
       [ 0.082     ,  0.025     ,  0.08400001],
       [ 0.22700001, -0.13500001,  0.16700001],
       [ 0.24400002, -0.18300001,  0.25500003],
       [ 0.32300001, -0.15700001,  0.062     ],
       [ 0.39200002, -0.072     ,  0.06500001],
       [ 0.40600002, -0.28500003,  0.09500001],
       [ 0.48100004, -0.30500001,  0.019     ],
       [ 0.46000001, -0.266     ,  0.18800001],
       [ 0.33800003, -0.36900002,  0.10600001],
       [ 0.26500002, -0.16900001, -0.08400001],
       [ 0.32900003, -0.125     , -0.18400002],
       [ 0.13900001, -0.21700001, -0.093     ],
       [ 0.109     , -0.259     , -0.007     ],
       [ 0.055     , -0.223     , -0.215     ],
       [ 0.108     , -0.28300002, -0.289     ],
       [-0.067     , -0.29900002, -0.17      ],
       [-0.13100001, -0.24900001, -0.098     ],
       [-0.032     , -0.39100003, -0.123     ],
       [-0.15200001, -0.34900001, -0.28600001],
       [-0.08400001, -0.39800003, -0.35600001],
       [-0.19900002, -0.266     , -0.33900002],
       [-0.26500002, -0.44000003, -0.252     ],
       [-0.324     , -0.44500002, -0.34300002],
       [-0.31500003, -0.40100002, -0.163     ],
       [-0.21800001, -0.57800001, -0.23400001],
       [-0.17200001, -0.62200004, -0.31200001],
       [-0.21900001, -0.648     , -0.115     ],
       [-0.17      , -0.76700002, -0.12      ],
       [-0.13900001, -0.80500007, -0.20900001],
       [-0.163     , -0.82800007, -0.04      ],
       [-0.26500002, -0.59600002, -0.004     ],
       [-0.30000001, -0.50200003, -0.005     ],
       [-0.24200001, -0.63500005,  0.086     ],
       [ 0.021     , -0.085     , -0.28100002],
       [-0.022     , -0.07700001, -0.39600003],
       [ 0.036     ,  0.025     , -0.21000001],
       [ 0.063     ,  0.021     , -0.112     ],
       [-0.017     ,  0.156     , -0.24200001],
       [-0.002     ,  0.18100001, -0.347     ],
       [-0.17300001,  0.163     , -0.22100002],
       [-0.208     ,  0.19500001, -0.123     ],
       [-0.21700001,  0.23100001, -0.294     ],
       [-0.21300001,  0.063     , -0.23800001],
       [ 0.059     ,  0.27200001, -0.18000001],
       [ 0.021     ,  0.38700002, -0.20300001],
       [ 0.17900001,  0.252     , -0.11800001],
       [ 0.21700001,  0.16000001, -0.10700001],
       [ 0.24400002,  0.36200002, -0.039     ],
       [ 0.16600001,  0.42900002, -0.003     ],
       [ 0.31200001,  0.30500001,  0.085     ],
       [ 0.34100002,  0.38900003,  0.149     ],
       [ 0.252     ,  0.22600001,  0.132     ],
       [ 0.40500003,  0.257     ,  0.056     ],
       [ 0.33200002,  0.45200002, -0.12800001],
       [ 0.45300001,  0.45600003, -0.108     ],
       [ 0.27200001,  0.51300001, -0.23600002],
       [ 0.17200001,  0.50400001, -0.245     ],
       [ 0.35300002,  0.59500003, -0.333     ],
       [ 0.44200003,  0.62900001, -0.28100002],
       [ 0.38900003,  0.48500001, -0.44400004],
       [ 0.48800004,  0.52000004, -0.47400004],
       [ 0.41000003,  0.38800001, -0.40100002],
       [ 0.32600001,  0.49100003, -0.53400004],
       [ 0.27500001,  0.71700001, -0.38800001],
       [ 0.155     ,  0.71700001, -0.38900003],
       [ 0.34300002,  0.82200003, -0.42400002],
       [ 0.44400004,  0.82400006, -0.42800003],
       [ 0.27000001,  0.94100004, -0.46900001],
       [ 0.21300001,  0.98500007, -0.38600001],
       [ 0.37400001,  1.04500008, -0.523     ],
       [ 0.32900003,  1.13900006, -0.55600005],
       [ 0.43300003,  1.        , -0.60400003],
       [ 0.44800001,  1.05800009, -0.44400004],
       [ 0.16600001,  0.92500007, -0.58700001],
       [ 0.067     ,  1.00300002, -0.59300005],
       [ 0.19000001,  0.83100003, -0.68200004],
       [ 0.26800001,  0.76800001, -0.69100004],
       [ 0.08000001,  0.79100001, -0.78100002],
       [ 0.034     ,  0.87300003, -0.83700001],
       [ 0.149     ,  0.69800001, -0.88500005],
       [ 0.24600001,  0.74100006, -0.90900004],
       [ 0.16600001,  0.597     , -0.84600002],
       [ 0.06900001,  0.68100005, -1.00999999],
       [-0.021     ,  0.62300003, -0.98900002],
       [ 0.029     ,  0.77400005, -1.0480001 ],
       [ 0.14      ,  0.62600005, -1.12600005],
       [ 0.078     ,  0.58100003, -1.20400012],
       [ 0.2       ,  0.54300004, -1.08900011],
       [ 0.24400002,  0.71400005, -1.18100011],
       [ 0.33200002,  0.71400005, -1.13100004],
       [ 0.23200001,  0.80300003, -1.27200007],
       [ 0.32500002,  0.89400005, -1.29900002],
       [ 0.41300002,  0.89400005, -1.25100005],
       [ 0.30100003,  0.96000004, -1.3720001 ],
       [ 0.14      ,  0.78100002, -1.36500001],
       [ 0.08000001,  0.70100003, -1.35200012],
       [ 0.11700001,  0.84500003, -1.44000006],
       [-0.05      ,  0.72700006, -0.71800005],
       [-0.15200001,  0.71700001, -0.79000002],
       [-0.053     ,  0.69500005, -0.58700001],
       [ 0.034     ,  0.71200001, -0.53800005],
       [-0.15800001,  0.61300004, -0.52500004],
       [-0.24900001,  0.62700003, -0.58400005],
       [-0.11400001,  0.46500003, -0.52500004],
       [-0.063     ,  0.43100002, -0.61500001],
       [-0.056     ,  0.44700003, -0.43500003],
       [-0.20300001,  0.40400001, -0.51700002],
       [-0.208     ,  0.66000003, -0.38000003],
       [-0.30000001,  0.60200006, -0.324     ],
       [-0.134     ,  0.74700004, -0.31600001],
       [-0.041     ,  0.76800001, -0.35200003],
       [-0.15400001,  0.81400001, -0.18900001],
       [-0.06500001,  0.86800003, -0.15700001],
       [-0.26700002,  0.91600007, -0.20900001],
       [-0.354     ,  0.88500005, -0.266     ],
       [-0.31200001,  0.94700003, -0.115     ],
       [-0.22000001,  1.00100005, -0.259     ],
       [-0.16900001,  0.72700006, -0.059     ],
       [-0.25800002,  0.75500005,  0.016     ],
       [-0.094     ,  0.61400002, -0.051     ],
       [-0.026     ,  0.59900004, -0.123     ],
       [-0.104     ,  0.514     ,  0.058     ],
       [-0.055     ,  0.42200002,  0.024     ],
       [-0.03      ,  0.56700003,  0.17500001],
       [ 0.07      ,  0.597     ,  0.14600001],
       [-0.016     ,  0.49300003,  0.25400001],
       [-0.085     ,  0.64700001,  0.22400001],
       [-0.24900001,  0.47100002,  0.09100001],
       [-0.29200003,  0.45200002,  0.20300001],
       [-0.33200002,  0.44400004, -0.014     ],
       [-0.29100001,  0.44900003, -0.10700001],
       [-0.47900003,  0.42000002,  0.003     ],
       [-0.52500004,  0.51500005,  0.027     ],
       [-0.53600001,  0.38600001, -0.13600001],
       [-0.53500003,  0.47800002, -0.19400001],
       [-0.64000005,  0.35500002, -0.126     ],
       [-0.47400004,  0.31200001, -0.18700001],
       [-0.53200001,  0.324     ,  0.109     ],
       [-0.634     ,  0.347     ,  0.17600001],
       [-0.46000001,  0.21300001,  0.14300001],
       [-0.37400001,  0.19600001,  0.09200001],
       [-0.49600002,  0.116     ,  0.25      ],
       [-0.59600002,  0.081     ,  0.22500001],
       [-0.40800002, -0.015     ,  0.23900001],
       [-0.44300002, -0.08400001,  0.31600001],
       [-0.30500001,  0.009     ,  0.264     ],
       [-0.40000001, -0.089     ,  0.1       ],
       [-0.35700002, -0.017     ,  0.03      ],
       [-0.34500003, -0.18100001,  0.11700001],
       [-0.528     , -0.13900001,  0.035     ],
       [-0.5       , -0.162     , -0.068     ],
       [-0.59800005, -0.055     ,  0.044     ],
       [-0.58900005, -0.25600001,  0.10200001],
       [-0.53800005, -0.30500001,  0.17400001],
       [-0.70300001, -0.31      ,  0.068     ],
       [-0.73900002, -0.43000001,  0.098     ],
       [-0.70000005, -0.47900003,  0.178     ],
       [-0.82900006, -0.46200001,  0.06500001],
       [-0.78700006, -0.257     , -0.011     ],
       [-0.76600003, -0.16500001, -0.047     ],
       [-0.86400002, -0.30900002, -0.05      ],
       [-0.50100005,  0.171     ,  0.39700001],
       [-0.56      ,  0.11100001,  0.48300001],
       [-0.42300001,  0.27700001,  0.42100003],
       [-0.39200002,  0.32900003,  0.34      ],
       [-0.40200001,  0.35000002,  0.54800004],
       [-0.41900003,  0.27200001,  0.62200004],
       [-0.25300002,  0.38500002,  0.55200005],
       [-0.19600001,  0.29300001,  0.53800005],
       [-0.24800001,  0.45800003,  0.47200003],
       [-0.215     ,  0.43900001,  0.63900006],
       [-0.49200001,  0.47600001,  0.56600004],
       [-0.49500003,  0.53100002,  0.67600006],
       [-0.55900002,  0.53100002,  0.45900002],
       [-0.55700004,  0.48900002,  0.36700001],
       [-0.63000005,  0.65200001,  0.46100003],
       [-0.57000005,  0.72400004,  0.51700002],
       [-0.64500004,  0.71100003,  0.32000002],
       [-0.70100003,  0.80400002,  0.33400002],
       [-0.54800004,  0.73200005,  0.27500001],
       [-0.71500003,  0.64500004,  0.26800001],
       [-0.77300006,  0.648     ,  0.537     ],
       [-0.86900002,  0.58500004,  0.48400003],
       [-0.77500004,  0.69600004,  0.66300005],
       [-0.68600005,  0.72800004,  0.69800001],
       [-0.88600004,  0.69300002,  0.75500005],
       [-0.86500007,  0.76600003,  0.83300006],
       [-0.97200006,  0.72600001,  0.69600004],
       [-0.90200007,  0.60200006,  0.81200004]])

In [410]:
align.rms.rmsd(a,b,np.ones(len(a)))


Out[410]:
0.1967593781226804

In [412]:
grad(lambda w:align.rms.rmsd(a,b,w))(np.ones(len(a)))


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-412-66ba4fd44497> in <module>()
----> 1 grad(lambda w:align.rms.rmsd(a,b,w))(np.ones(len(a)))

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in gradfun(*args, **kwargs)
     13     the same type as the argument."""
     14     def gradfun(*args,**kwargs):
---> 15         return backward_pass(*forward_pass(fun,args,kwargs,argnum))
     16 
     17     try:

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in forward_pass(fun, args, kwargs, argnum)
     31         args = list(args)
     32         args[argnum] = merge_tapes(start_node, arg_wrt)
---> 33         end_node = fun(*args, **kwargs)
     34         return start_node, end_node, tape
     35 

<ipython-input-412-66ba4fd44497> in <lambda>(w)
----> 1 grad(lambda w:align.rms.rmsd(a,b,w))(np.ones(len(a)))

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/MDAnalysis/analysis/rms.pyc in rmsd(a, b, weights, center)
    169     if weights is not None:
    170         # weights are constructed as relative to the mean
--> 171         relative_weights = numpy.asarray(weights) / numpy.mean(weights)
    172     else:
    173         relative_weights = None

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/numpy/core/fromnumeric.pyc in mean(a, axis, dtype, out, keepdims)
   2728         try:
   2729             mean = a.mean
-> 2730             return mean(axis=axis, dtype=dtype, out=out)
   2731         except AttributeError:
   2732             pass

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in __call__(self, *args, **kwargs)
    110             result = new_node(result, tapes)
    111             for tape, argnum, parent in ops:
--> 112                 gradfun = self.gradmaker(argnum, result, *args, **kwargs)
    113                 rnode = result.tapes[tape]
    114                 rnode.parent_grad_ops.append((gradfun, parent))

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in gradmaker(self, argnum, *args, **kwargs)
     77     def gradmaker(self, argnum, *args, **kwargs):
     78         try:
---> 79             return self.grads[argnum](*args, **kwargs)
     80         except KeyError:
     81             if self.grads == {}:

TypeError: make_grad_np_mean() got an unexpected keyword argument 'dtype'

In [413]:
%timeit align.rms.rmsd(a,b)


The slowest run took 12.29 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 13 µs per loop

In [384]:
align.rms.rmsd(a,a,10*np.ones(fs_atoms))


Out[384]:
4.640215302476424e-08

In [ ]:


In [385]:
import numpy.linalg as la

In [389]:
u,s,v=la.svd(np.ones((10,10)))

In [396]:
def can_has_svd_derivative(m):
    u,s,v=la.svd(m)
    return np.sum(u)

In [400]:
plt.hist([can_has_svd_derivative(np.random.randn(10,10)) for i in range(10000)],bins=50);



In [401]:
grad(can_has_svd_derivative)(np.ones((10,10)))


/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/numpy/linalg/linalg.py:1327: DeprecationWarning: Implicitly casting between incompatible kinds. In a future numpy release, this will raise an error. Use casting="unsafe" if this is intentional.
  u, s, vt = gufunc(a, signature=signature, extobj=extobj)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-401-0d42f9bcfaf5> in <module>()
----> 1 grad(can_has_svd_derivative)(np.ones((10,10)))

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in gradfun(*args, **kwargs)
     13     the same type as the argument."""
     14     def gradfun(*args,**kwargs):
---> 15         return backward_pass(*forward_pass(fun,args,kwargs,argnum))
     16 
     17     try:

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/autograd/core.pyc in forward_pass(fun, args, kwargs, argnum)
     31         args = list(args)
     32         args[argnum] = merge_tapes(start_node, arg_wrt)
---> 33         end_node = fun(*args, **kwargs)
     34         return start_node, end_node, tape
     35 

<ipython-input-396-60ea00a7f68c> in can_has_svd_derivative(m)
      1 def can_has_svd_derivative(m):
----> 2     u,s,v=la.svd(m)
      3     return np.sum(u)

/Users/joshuafass/anaconda/envs/py27/lib/python2.7/site-packages/numpy/linalg/linalg.pyc in svd(a, full_matrices, compute_uv)
   1325 
   1326         signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
-> 1327         u, s, vt = gufunc(a, signature=signature, extobj=extobj)
   1328         u = u.astype(result_t)
   1329         s = s.astype(_realType(result_t))

TypeError: float() argument must be a string or a number

In [414]:
def can_has_det_deriv(m):
    return np.linalg.det(m)

In [416]:
np.random.seed(0)
m = np.random.randn(10,10)
np.linalg.det(m)


Out[416]:
55.921478444018824

In [417]:
grad(np.linalg.det)(m)


Out[417]:
array([[  3.26689297e+01,   2.41157719e+02,  -1.36510452e+02,
          2.70083908e+01,  -1.21420666e+02,  -1.02107211e+02,
         -4.34478612e+00,  -1.29692486e+02,  -7.47336897e+01,
          1.91508128e+02],
       [ -5.03241338e+01,  -5.25953301e+02,   2.73235722e+02,
         -5.25773306e+01,   3.30054896e+02,   2.27235382e+02,
          7.36456391e+00,   2.82679054e+02,   1.69057371e+02,
         -4.66279585e+02],
       [ -1.15278778e+01,  -1.05033293e+01,   9.68732507e+00,
         -1.56991627e+00,   4.28698087e+00,  -7.95248784e+00,
          5.58982214e+00,   6.79406395e+00,   4.65312730e+00,
         -2.44705117e+00],
       [ -1.22245932e+01,  -1.57427909e+02,   8.36492880e+01,
         -2.33373075e+01,   7.57112260e+01,   4.55714745e+01,
          2.32892204e+01,   9.70634834e+01,   3.68251151e+01,
         -1.10831414e+02],
       [ -3.23354975e+01,  -2.04417611e+02,   1.19108721e+02,
         -1.37110905e+01,   1.02245409e+02,   6.72908784e+01,
          1.42964657e+01,   1.12127771e+02,   3.69521168e+01,
         -1.58293802e+02],
       [  8.77906319e+01,   1.18130442e+03,  -5.97789993e+02,
          1.19951167e+02,  -6.96867519e+02,  -4.99282907e+02,
         -1.75844838e+01,  -6.31784525e+02,  -3.77927426e+02,
          9.35907137e+02],
       [ -1.56662031e+01,  -4.05643525e+02,   1.69340008e+02,
         -4.88676753e+01,   2.90318763e+02,   1.95103658e+02,
         -2.32252715e+01,   2.15890316e+02,   1.53639625e+02,
         -3.58478140e+02],
       [ -9.00861880e+00,   3.02977311e+01,   3.50673455e+01,
         -1.01684085e-01,  -5.24682539e+01,  -3.96194133e+01,
          8.22745562e+00,  -6.04810379e+00,  -4.49773055e+01,
          3.79142398e+01],
       [ -1.70037340e+01,  -1.85984435e+02,   9.65191552e+01,
         -1.97794741e+01,   1.11337831e+02,   8.73881874e+01,
          7.30362975e+00,   9.36092675e+01,   4.73072127e+01,
         -1.36726656e+02],
       [  2.64486552e+00,   5.00259986e+01,  -8.88307794e+00,
          8.08764356e+00,  -3.04057563e+01,  -1.33929622e+01,
         -8.44105029e+00,  -2.12760228e+00,  -1.29221512e+01,
          3.88371830e+01]])

In [481]:
def BC_w_vec(X,Y,m):
    M = np.diag(m)
    xx = np.dot(np.dot(X.T,M),X)
    xy = np.dot(np.dot(X.T,M),Y)
    yy = np.dot(np.dot(Y.T,M),Y)
    return 1-np.linalg.det(xy) / np.sqrt(np.linalg.det(xx) * np.linalg.det(yy))

In [482]:
BC_w_vec(fs_triplets_xyz[0][0],fs_triplets_xyz[0][1],np.ones(fs_atoms))


Out[482]:
0.1010878238059435

In [485]:
plt.plot(grad(lambda w:BC_w_vec(a,b,w))(np.ones(264)));



In [486]:
%timeit grad(lambda w:BC_w_vec(a,b,w))(np.ones(264))


100 loops, best of 3: 19.7 ms per loop

In [487]:
%timeit BC_w_vec(a,b,np.ones(264))


1000 loops, best of 3: 290 µs per loop

In [428]:
grad_bc = grad(lambda w:BC_w_vec(a,b,w))

In [430]:
%timeit grad_bc(np.ones(264))


100 loops, best of 3: 19.2 ms per loop

In [488]:
losses_unweighted = [zero_one_triplet_loss(BC_w_vec(a.xyz[0],b.xyz[0],np.ones(264)),BC_w_vec(a.xyz[0],c.xyz[0],np.ones(264))) for (a,b,c) in fs_triplets]

In [489]:
sum(losses_unweighted) / len(losses_unweighted)


Out[489]:
0.24194194194194193

In [490]:
def wbc_obj(weights,triplet,triplet_loss=modified_triplet_loss):
    return triplet_loss(BC_w_vec(triplet[0],triplet[1],weights),
                        BC_w_vec(triplet[0],triplet[2],weights))

In [491]:
from msmbuilder.featurizer import RawPositionsFeaturizer as rpf

In [492]:
fs_triplets_xyz = [(a.xyz[0],b.xyz[0],c.xyz[0]) for (a,b,c) in fs_triplets]

In [545]:
t = time()
results = sgd(wbc_obj,fs_triplets_xyz,np.ones(fs_atoms),n_iter=100,step_size=10)
plt.plot(results[:,:10]);
print(time()-t)


2.40043997765

In [ ]:
results = sgd(wbc_obj,fs_triplets_xyz,np.ones(fs_atoms),n_iter=10000,step_size=1,store_intermediates=False)

In [541]:
losses_weighted = [zero_one_triplet_loss(BC_w_vec(a.xyz[0],b.xyz[0],results[-1]),BC_w_vec(a.xyz[0],c.xyz[0],results[-1])) for (a,b,c) in fs_triplets]

In [542]:
sum(losses_weighted) / len(losses_weighted)


Out[542]:
0.26056056056056054

In [472]:
# idea: instead of picking the points randomly, you can pick time-points on either side of detected change-points

In [ ]: