Logistic MF with multiple kernels using TensorFlow and Edward

This is a somewhat more accessible demonstration of VB-MK-LMF with a different variational approximation strategy and slightly modified priors. In particular, this version uses black-box variational inference (BBVI, Ranganath et al. (2014)), as implemented in the Edward package by the Blei lab (Tran et al. (2016)). Since Gamma distributions lead to very noisy gradients with BBVI, they have been replaced by LogNormals. We also impose priors on the $\alpha$ parameters ($L_2$ regularization).
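
As background for the gradient-noise remark: plain BBVI estimates the ELBO gradient with the score-function estimator,

$$\nabla_\lambda \mathcal{L} \approx \frac{1}{S}\sum_{s=1}^{S} \nabla_\lambda \log q(z_s \mid \lambda)\,\left[\log p(x, z_s) - \log q(z_s \mid \lambda)\right], \qquad z_s \sim q(z \mid \lambda),$$

whose variance depends strongly on the variational family. With the Exp-transformed Normals (LogNormals) used below, Edward's KLqp can additionally fall back on reparameterization gradients, which further reduces the noise.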


In [1]:
%pylab inline

import edward as ed
from edward.models import Normal, MultivariateNormalTriL, TransformedDistribution, NormalWithSoftplusScale
from edward.models.random_variable import RandomVariable

import tensorflow as tf
from tensorflow.contrib.distributions import Distribution

import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score, auc


Populating the interactive namespace from numpy and matplotlib

In [2]:
# Load interaction matrix
admat = "data/nr/mat/nr_admat_dgc.txt"
with open(admat) as f:
    ncols = len(f.readline().split('\t'))
R_ = np.loadtxt(admat,skiprows=1,usecols=range(1,ncols),delimiter='\t',dtype=np.float32)
I,J = R_.shape

# Load similarity matrices
simmat_u = ["data/nr/mat/nr_simmat_dg.txt"]
Ku = np.array([np.loadtxt(mat,skiprows=1,usecols=range(1,I+1),delimiter='\t',dtype=np.float32) for mat in simmat_u])

simmat_v = ["data/nr/mat/nr_simmat_dc.txt",
            "data/nr/mat/nr_simmat_dc_maccs_rbf.txt",
            "data/nr/mat/nr_simmat_dc_maccs_tanimoto.txt",
            "data/nr/mat/nr_simmat_dc_morgan_rbf.txt",
            "data/nr/mat/nr_simmat_dc_morgan_tanimoto.txt"]
Kv = np.array([np.loadtxt(mat,skiprows=1,usecols=range(1,J+1),delimiter='\t',dtype=np.float32) for mat in simmat_v])
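
A quick optional shape check (not part of the original pipeline): the first similarity set must match the rows of `R_`, the second its columns.

# Optional sanity check: kernels must match the interaction matrix dimensions
assert Ku.shape[1:] == (I,I)
assert Kv.shape[1:] == (J,J)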

In [3]:
# Nearest-neighbors truncation + regularization (in place):
#  keep the 5 largest similarities per row, symmetrize, then shift the
#  diagonal so the smallest eigenvalue becomes 0.1 (positive definiteness)
def truncate_kernel(K):
    idx = np.argsort(-K,axis=1)
    for i in range(K.shape[0]):
        K[i,idx[i,5:]] = 0
    K += K.T
    K -= (np.real_if_close(np.min(np.linalg.eigvals(K))-0.1))*np.eye(K.shape[0])

for i in range(len(Ku)):
    truncate_kernel(Ku[i])

for i in range(len(Kv)):
    truncate_kernel(Kv[i])
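
Optionally, one can verify that `truncate_kernel` left every kernel symmetric with its smallest eigenvalue at the 0.1 floor (a sanity check, not part of the original pipeline):

# Optional sanity check: symmetry and the eigenvalue floor after truncation
for K in list(Ku)+list(Kv):
    assert np.allclose(K,K.T)
    assert np.min(np.linalg.eigvalsh(K)) >= 0.1-1e-3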

In [4]:
# Load CV folds
#  each line is one fold: a comma-separated list of 1-based (column,row) index pairs
folds = []
with open("data/nr/cv/nr_all_folds_cvs1.txt") as f:
    for i in f.readlines():
        rec = i.strip().split(",")
        ln = len(rec)//2
        folds += [[(int(rec[j*2])-1,int(rec[j*2+1])-1) for j in range(ln)]]

In [5]:
# Latent dims and augmented Bernoulli parameter
L  = 12
c  = 3.0

# Insert your favorite neural network here
#  default: bilinear interaction, logits = Uw1^T Vw1 of shape [I,J]
def nn(Uw1,Vw1):
    return tf.matmul(Uw1,Vw1,transpose_a = True)
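
The bilinear product can be swapped for other interaction functions. As a hypothetical sketch (not wired into the model below), a positive weight per latent dimension could be learned jointly with the variational parameters; any such variables should be created inside `construct_model()` so they survive the per-fold graph resets:

# Hypothetical variant: per-dimension positive weights on the latent factors.
# `w` would be e.g. tf.nn.softplus(tf.Variable(tf.zeros([L,1]))), created in construct_model().
def nn_weighted(Uw1,Vw1,w):
    return tf.matmul(w*Uw1,Vw1,transpose_a=True)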

In [6]:
# Augmented Bernoulli distribution
#  sampling is not used and therefore omitted

class dAugmentedBernoulli(Distribution):
    def __init__(self,logits,c,obs,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="AugmentedBernoulli"):
        parameters = locals()
        with tf.name_scope(name):
            with tf.control_dependencies([]):
                self._logits = tf.identity(logits)
                self._c      = tf.identity(c)
                self._obs    = tf.identity(obs)
        super(dAugmentedBernoulli,self).__init__(dtype=tf.int32,validate_args=validate_args,allow_nan_stats=allow_nan_stats,
                                                 reparameterization_type=tf.contrib.distributions.NOT_REPARAMETERIZED,
                                                 parameters=parameters,graph_parents=[self._logits,self._c,self._obs],name=name)

    def _log_prob(self,event):
        # Numerically stable evaluation, branching on the sign of the logits:
        #  r=1: c*log(sigmoid(x)), i.e. positives are weighted c-fold;
        #  r=0: log(1-sigmoid(x)); masked elementwise by the observation matrix
        event = tf.cast(event,tf.float32)
        cond = self._logits >= 0
        neg_abs = tf.where(cond,-self._logits,self._logits)
        sig = ((self._c-1.0)*event+1.0)*tf.log1p(tf.exp(neg_abs))
        return self._obs * tf.where(cond,(event-1)*self._logits-sig,self._c*event*self._logits-sig)

# Wrap the distribution as an Edward RandomVariable
def __init__(self, *args, **kwargs):
    RandomVariable.__init__(self, *args, **kwargs)
AugmentedBernoulli = type("AugmentedBernoulli", (RandomVariable, dAugmentedBernoulli), {'__init__': __init__})
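
For reference, the (unnormalized) likelihood implemented above is the importance-weighted Bernoulli of logistic matrix factorization, which counts each positive observation $c$ times:

$$\log p(r \mid x) = \mathrm{obs} \cdot \left[\, c\,r\,x - \big(1 + (c-1)\,r\big)\,\log\big(1 + e^{x}\big) \right],$$

where $\sigma$ denotes the logistic sigmoid, so that $p(1 \mid x) = \sigma(x)^c$ and $p(0 \mid x) = 1 - \sigma(x)$; `_log_prob` evaluates this expression in a numerically stable way by branching on the sign of the logits.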

In [7]:
# Construct VB-MK-LMF model
# Gamma distributions can lead to very noisy gradients so LogNormals are used instead

def construct_model():
    nku = len(Ku)
    nkv = len(Kv)

    obs = tf.placeholder(tf.float32,R_.shape)

    # LogNormal priors (Exp-transformed standard Normals) on the kernel weights
    Ug  = TransformedDistribution(distribution=Normal(tf.zeros([nku]),tf.ones([nku])),
                                  bijector=tf.contrib.distributions.bijectors.Exp())
    Vg  = TransformedDistribution(distribution=Normal(tf.zeros([nkv]),tf.ones([nkv])),
                                  bijector=tf.contrib.distributions.bijectors.Exp())

    # LogNormal priors on the regularization (alpha) parameters
    Ua  = TransformedDistribution(distribution=Normal(tf.zeros([1]),tf.ones([1])),
                                  bijector=tf.contrib.distributions.bijectors.Exp())
    Va  = TransformedDistribution(distribution=Normal(tf.zeros([1]),tf.ones([1])),
                                  bijector=tf.contrib.distributions.bijectors.Exp())

    # Cholesky factors of the kernels regularized by 1/alpha on the diagonal
    cKu = tf.cholesky(Ku+tf.eye(I)/Ua) #TODO: rank 1 chol update
    cKv = tf.cholesky(Kv+tf.eye(J)/Va)

    # Multiple-kernel priors on the latent factors: per-kernel Cholesky factors
    # scaled by the inverse square roots of the kernel weights, then summed
    Uw1 = MultivariateNormalTriL(tf.zeros([L,I]),tf.reduce_sum(cKu/tf.reshape(tf.sqrt(Ug),[nku,1,1]),axis=0))
    Vw1 = MultivariateNormalTriL(tf.zeros([L,J]),tf.reduce_sum(cKv/tf.reshape(tf.sqrt(Vg),[nkv,1,1]),axis=0))

    logits = nn(Uw1,Vw1)
    # Observed interactions, initialized at the sign of the logits
    R   = AugmentedBernoulli(logits=logits,c=c,obs=obs,value=tf.cast(logits>0,tf.int32))

    # Variational posteriors: mean-field LogNormals for the kernel weights and
    # alpha parameters, full-covariance Gaussians for the latent factors
    qUg  = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nku])),
                                                                    tf.Variable(tf.ones([nku]))),
                                   bijector=tf.contrib.distributions.bijectors.Exp())
    qVg  = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nkv])),
                                                                        tf.Variable(tf.ones([nkv]))),
                                   bijector=tf.contrib.distributions.bijectors.Exp())
    qUa  = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                                                        tf.Variable(tf.ones([1]))),
                                   bijector=tf.contrib.distributions.bijectors.Exp())
    qVa  = TransformedDistribution(distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                                                        tf.Variable(tf.ones([1]))),
                                   bijector=tf.contrib.distributions.bijectors.Exp())
    qUw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L,I])),tf.Variable(tf.eye(I)))
    qVw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L,J])),tf.Variable(tf.eye(J)))
    
    return obs,Ug,Vg,Ua,Va,cKu,cKv,Uw1,Vw1,R,qUg,qVg,qUa,qVa,qUw1,qVw1
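
Note that `qUw1` and `qVw1` carry full $I \times I$ and $J \times J$ triangular scale matrices. On larger datasets a mean-field approximation over the factor entries is a common, cheaper alternative; a sketch of the drop-in replacement inside `construct_model()` (a simplification, not used in the experiments below):

# Hypothetical mean-field alternative: independent Gaussians per factor entry
qUw1_mf = NormalWithSoftplusScale(tf.Variable(tf.zeros([L,I])),tf.Variable(tf.ones([L,I])))
qVw1_mf = NormalWithSoftplusScale(tf.Variable(tf.zeros([L,J])),tf.Variable(tf.ones([L,J])))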

In [8]:
auroc_all = []
aupr_all  = []
for f in folds:
    # Edward does not delete nodes so we have to reset the graph manually
    ed.get_session().close()
    tf.reset_default_graph()
    obs,Ug,Vg,Ua,Va,cKu,cKv,Uw1,Vw1,R,qUg,qVg,qUa,qVa,qUw1,qVw1 = construct_model()

    # Hide test examples
    cv = np.zeros((I,J),dtype=np.bool)
    for i in f:
        cv[i[1],i[0]] = True
    data = np.copy(R_)
    data[cv] = 0

    # Observation mask for the augmented Bernoulli: an entry contributes to the
    # likelihood only if its row and column each contain a training positive
    obs_ = (np.logical_and.outer(np.any(data>0,axis=1),np.any(data>0,axis=0))*1).astype(np.float32)

    # Variational approximation using BBVI (10 Monte Carlo samples per gradient estimate)
    inference = ed.KLqp({Uw1: qUw1, Vw1: qVw1, Ug: qUg, Vg: qVg, Ua: qUa, Va: qVa},data={R: data, obs: obs_})
    inference.initialize(n_samples=10,n_iter=3000)
    tf.global_variables_initializer().run()
    for _ in range(inference.n_iter):
        info_dict = inference.update()
        inference.print_progress(info_dict)
    inference.finalize()

    # Evaluation: under the augmented Bernoulli, p(interaction) = sigmoid(logits)^c;
    # the ranking metrics below are invariant to this monotone transform
    res = (tf.nn.sigmoid(nn(qUw1.mean(),qVw1.mean()))**c).eval()

    prc,rec,_ = precision_recall_curve(R_[cv],res[cv])
    fpr,tpr,_ = roc_curve(R_[cv],res[cv])

    auroc = auc(fpr,tpr,reorder=True)
    aupr  = auc(rec,prc,reorder=True)
    auroc_all += [auroc]
    aupr_all  += [aupr]
    print("AUPR: {}\tAUROC: {}".format(aupr,auroc))
print("Overall\nAUPR: {} +- {}, AUROC: {} +- {}".format(np.mean(aupr_all),np.std(aupr_all)*2,np.mean(auroc_all),np.std(auroc_all)*2))


3000/3000 [100%] ██████████████████████████████ Elapsed: 165s | Loss: 686.201
AUPR: 0.7907644593258405	AUROC: 0.9738461538461538
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 698.946
AUPR: 0.8467848124098125	AUROC: 0.9829545454545454
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 667.156
AUPR: 0.808971088435374	AUROC: 0.9838882921589688
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 646.558
AUPR: 0.9190323884289402	AUROC: 0.9856770833333333
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 666.854
AUPR: 0.8046793292913982	AUROC: 0.97
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 688.001
AUPR: 0.8932687748477223	AUROC: 0.9753846153846154
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 670.754
AUPR: 0.6808335369827305	AUROC: 0.9577114427860698
3000/3000 [100%] ██████████████████████████████ Elapsed: 161s | Loss: 686.461
AUPR: 0.6745142323414259	AUROC: 0.9270568278201865
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 695.615
AUPR: 0.6809694927913402	AUROC: 0.9485553206483439
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 711.265
AUPR: 0.8807234432234432	AUROC: 0.9906152241918665
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 669.910
AUPR: 0.725818862207751	AUROC: 0.9609838846480068
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 695.032
AUPR: 0.9395833333333333	AUROC: 0.9933712121212122
3000/3000 [100%] ██████████████████████████████ Elapsed: 165s | Loss: 697.820
AUPR: 0.7795690915769746	AUROC: 0.9704016913319239
3000/3000 [100%] ██████████████████████████████ Elapsed: 161s | Loss: 667.250
AUPR: 0.7577885261935327	AUROC: 0.9526515151515151
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 619.678
AUPR: 0.8316798352337568	AUROC: 0.9576822916666666
3000/3000 [100%] ██████████████████████████████ Elapsed: 166s | Loss: 701.688
AUPR: 0.8128822790113112	AUROC: 0.9677765843179377
3000/3000 [100%] ██████████████████████████████ Elapsed: 161s | Loss: 701.307
AUPR: 0.7765460729746444	AUROC: 0.9828141783029002
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 664.228
AUPR: 0.7579489930832641	AUROC: 0.9464411557434813
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 687.055
AUPR: 1.0	AUROC: 1.0
3000/3000 [100%] ██████████████████████████████ Elapsed: 168s | Loss: 711.864
AUPR: 0.8903061224489797	AUROC: 0.9937434827945776
3000/3000 [100%] ██████████████████████████████ Elapsed: 161s | Loss: 670.811
AUPR: 0.7846884018759018	AUROC: 0.9820075757575757
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 687.799
AUPR: 0.9082483660130719	AUROC: 0.9676923076923076
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 633.585
AUPR: 0.7273976049750044	AUROC: 0.9753846153846154
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 682.021
AUPR: 0.7707399626517274	AUROC: 0.9602272727272727
3000/3000 [100%] ██████████████████████████████ Elapsed: 171s | Loss: 700.737
AUPR: 0.8490629880564442	AUROC: 0.9830866807610994
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 684.510
AUPR: 0.9380952380952381	AUROC: 0.997037037037037
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 673.411
AUPR: 0.7380120798319328	AUROC: 0.9744318181818181
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 660.716
AUPR: 0.8704351092455932	AUROC: 0.9774489076814659
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 685.691
AUPR: 0.836515185132825	AUROC: 0.9518229166666667
3000/3000 [100%] ██████████████████████████████ Elapsed: 165s | Loss: 705.102
AUPR: 0.7695120367896375	AUROC: 0.9478623566214807
3000/3000 [100%] ██████████████████████████████ Elapsed: 174s | Loss: 697.132
AUPR: 0.5592592592592593	AUROC: 0.9781021897810219
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 644.457
AUPR: 0.7866287094547963	AUROC: 0.9784615384615384
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 668.181
AUPR: 0.5801604278074866	AUROC: 0.9496296296296297
3000/3000 [100%] ██████████████████████████████ Elapsed: 166s | Loss: 708.151
AUPR: 0.8387084053962647	AUROC: 0.9795918367346939
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 664.040
AUPR: 0.7950231481481481	AUROC: 0.9753787878787878
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 725.362
AUPR: 0.9587912087912088	AUROC: 0.9957591178965225
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 703.045
AUPR: 0.7931517345080324	AUROC: 0.9583333333333334
3000/3000 [100%] ██████████████████████████████ Elapsed: 177s | Loss: 671.974
AUPR: 0.7451439763939763	AUROC: 0.9706439393939394
3000/3000 [100%] ██████████████████████████████ Elapsed: 165s | Loss: 644.585
AUPR: 0.8062266057809634	AUROC: 0.9402199904351984
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 682.581
AUPR: 0.786700036075036	AUROC: 0.953125
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 689.783
AUPR: 0.7802578863022942	AUROC: 0.9734848484848485
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 715.697
AUPR: 0.7903615991851286	AUROC: 0.9602577873254565
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 620.093
AUPR: 0.6135608510309467	AUROC: 0.9591261451726568
3000/3000 [100%] ██████████████████████████████ Elapsed: 163s | Loss: 671.031
AUPR: 0.6227494800521116	AUROC: 0.9584615384615384
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 665.176
AUPR: 0.8617065854119426	AUROC: 0.9674479166666667
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 646.275
AUPR: 0.8106610709551886	AUROC: 0.9821882951653944
3000/3000 [100%] ██████████████████████████████ Elapsed: 176s | Loss: 700.085
AUPR: 0.8524305555555556	AUROC: 0.9908088235294118
3000/3000 [100%] ██████████████████████████████ Elapsed: 164s | Loss: 681.846
AUPR: 0.8930596067628154	AUROC: 0.9901338971106413
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 679.086
AUPR: 0.783280522121248	AUROC: 0.9546153846153846
3000/3000 [100%] ██████████████████████████████ Elapsed: 162s | Loss: 689.961
AUPR: 0.818819973130318	AUROC: 0.9733455882352942
Overall
AUPR: 0.7984410657786333 +- 0.18423997840612222, AUROC: 0.9705534515705121 +- 0.032372401574204374
