In [1]:
import numpy as np
import tensorflow as tf

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def read_triples(path):
    # Each line holds one whitespace-separated (subject, predicate, object) triple
    triples = []
    with open(path, 'rt') as f:
        for line in f:
            s, p, o = line.split()
            triples += [(s, p, o)]
    return triples


def unit_cube_projection(var_matrix):
    # Clip every entry of the embedding matrix back into the unit cube [0, 1]
    projected = tf.minimum(1., tf.maximum(var_matrix, 0.))
    return tf.assign(var_matrix, projected)


def make_batches(size, batch_size):
    nb_batch = int(np.ceil(size / float(batch_size)))
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(nb_batch)]
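# For example, make_batches(25, 10) yields [(0, 10), (10, 20), (20, 25)]:
# full batches of batch_size, plus a smaller final batch if needed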

class IndexGenerator:
    def __init__(self):
        self.random_state = np.random.RandomState(0)

    def __call__(self, n_samples, candidate_indices):
        # Sample n_samples indices by shuffling the candidates and cycling
        # through them when more samples are requested than candidates exist
        shuffled_indices = candidate_indices[self.random_state.permutation(len(candidate_indices))]
        rand_ints = shuffled_indices[np.arange(n_samples) % len(shuffled_indices)]
        return rand_ints

class DistMult:
    def __init__(self, subject_embeddings=None, object_embeddings=None,
                 predicate_embeddings=None):
        self.subject_embeddings, self.object_embeddings = subject_embeddings, object_embeddings
        self.predicate_embeddings = predicate_embeddings

    def __call__(self):
        # Trilinear DistMult score: <e_s, w_p, e_o> = sum_k e_s[k] * w_p[k] * e_o[k]
        scores = tf.reduce_sum(self.subject_embeddings *
                               self.predicate_embeddings *
                               self.object_embeddings, axis=1)
        return scores
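
The DistMult score of a triple is the trilinear dot product of its subject, predicate and object embeddings. A minimal NumPy sketch with toy 3-dimensional embeddings (values chosen purely for illustration):

s_emb = np.array([0.1, 0.9, 0.2])
p_emb = np.array([0.5, 0.5, 0.5])
o_emb = np.array([0.3, 0.8, 0.1])

# <s, p, o> = 0.1*0.5*0.3 + 0.9*0.5*0.8 + 0.2*0.5*0.1 = 0.385
score = np.sum(s_emb * p_emb * o_emb)

The score is symmetric in the subject and object embeddings, so DistMult assigns the same score to (s, p, o) and (o, p, s); this is a known limitation of the model.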

In [2]:
entity_embedding_size = 150
predicate_embedding_size = 150

seed = 0
margin = 5

nb_epochs = 1000
nb_batches = 10

np.random.seed(seed)
random_state = np.random.RandomState(seed)
tf.set_random_seed(seed)

dataset_name = 'wn18'

train_triples = read_triples('{}/{}.train.tsv'.format(dataset_name, dataset_name))
valid_triples = read_triples('{}/{}.valid.tsv'.format(dataset_name, dataset_name))
test_triples = read_triples('{}/{}.test.tsv'.format(dataset_name, dataset_name))

In [3]:
all_triples = train_triples + valid_triples + test_triples

entity_set = {s for (s, p, o) in all_triples} | {o for (s, p, o) in all_triples}
predicate_set = {p for (s, p, o) in all_triples}

nb_entities, nb_predicates = len(entity_set), len(predicate_set)
nb_examples = len(train_triples)

entity_to_idx = {entity: idx for idx, entity in enumerate(sorted(entity_set))}
predicate_to_idx = {predicate: idx for idx, predicate in enumerate(sorted(predicate_set))}

entity_embedding_layer = tf.get_variable('entities', shape=[nb_entities, entity_embedding_size],
                                         initializer=tf.contrib.layers.xavier_initializer())

predicate_embedding_layer = tf.get_variable('predicates', shape=[nb_predicates, predicate_embedding_size],
                                            initializer=tf.contrib.layers.xavier_initializer())

subject_inputs = tf.placeholder(tf.int32, shape=[None])
predicate_inputs = tf.placeholder(tf.int32, shape=[None])
object_inputs = tf.placeholder(tf.int32, shape=[None])

target_inputs = tf.placeholder(tf.float32, shape=[None])

subject_embeddings = tf.nn.embedding_lookup(entity_embedding_layer, subject_inputs)
predicate_embeddings = tf.nn.embedding_lookup(predicate_embedding_layer, predicate_inputs)
object_embeddings = tf.nn.embedding_lookup(entity_embedding_layer, object_inputs)

model = DistMult(subject_embeddings=subject_embeddings,
                 predicate_embeddings=predicate_embeddings,
                 object_embeddings=object_embeddings)

scores = model()

In [4]:
import math
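# Pairwise hinge loss: targets are 1.0 for positive and 0.0 for corrupted
# triples, so (2 * target_inputs - 1) maps them to {+1, -1}; a positive triple
# is penalised when its score falls below `margin`, while a corrupted one is
# penalised when its score rises above -margin.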

hinge_losses = tf.nn.relu(margin - scores * (2 * target_inputs - 1))
loss = tf.reduce_sum(hinge_losses)

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
training_step = optimizer.minimize(loss)

# After each gradient step, project the entity embeddings back onto the unit cube
projection_step = unit_cube_projection(entity_embedding_layer)

batch_size = math.ceil(nb_examples / nb_batches)
batches = make_batches(nb_examples, batch_size)

nb_versions = 3

Xs = np.array([entity_to_idx[s] for (s, p, o) in train_triples], dtype=np.int32)
Xp = np.array([predicate_to_idx[p] for (s, p, o) in train_triples], dtype=np.int32)
Xo = np.array([entity_to_idx[o] for (s, p, o) in train_triples], dtype=np.int32)

index_gen = IndexGenerator()

init_op = tf.global_variables_initializer()
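
Each training batch interleaves nb_versions = 3 variants of every positive triple: positions 0::3 hold the original triples, positions 1::3 a subject-corrupted copy, and positions 2::3 an object-corrupted copy, matching the [1.0, 0.0, 0.0] target pattern fed to the loss. A toy illustration, assuming a batch of two triples with indices (0, 0, 1) and (2, 1, 3):

# position:   0  1  2    3  4  5
# Xs_batch = [0, ?, 0,   2, ?, 2]   # ? = randomly sampled corrupting subject
# Xp_batch = [0, 0, 0,   1, 1, 1]
# Xo_batch = [1, 1, ?,   3, 3, ?]   # ? = randomly sampled corrupting object
# targets  = [1, 0, 0,   1, 0, 0]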

In [5]:
def stats(values):
    return '{0:.4f} ± {1:.4f}'.format(np.mean(values), np.std(values))

session = tf.Session()
session.run(init_op)

for epoch in range(1, nb_epochs + 1):
    order = random_state.permutation(nb_examples)
    Xs_shuf, Xp_shuf, Xo_shuf = Xs[order], Xp[order], Xo[order]
    
    loss_values = []

    for batch_no, (batch_start, batch_end) in enumerate(batches):
        curr_batch_size = batch_end - batch_start

        Xs_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xs_shuf.dtype)
        Xp_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xp_shuf.dtype)
        Xo_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xo_shuf.dtype)

        # Version 0: the original (positive) triples
        Xs_batch[0::nb_versions] = Xs_shuf[batch_start:batch_end]
        Xp_batch[0::nb_versions] = Xp_shuf[batch_start:batch_end]
        Xo_batch[0::nb_versions] = Xo_shuf[batch_start:batch_end]

        # Version 1: corrupt the subject, keeping predicate and object fixed
        Xs_batch[1::nb_versions] = index_gen(curr_batch_size, np.arange(nb_entities))
        Xp_batch[1::nb_versions] = Xp_shuf[batch_start:batch_end]
        Xo_batch[1::nb_versions] = Xo_shuf[batch_start:batch_end]

        # Version 2: corrupt the object, keeping subject and predicate fixed
        Xs_batch[2::nb_versions] = Xs_shuf[batch_start:batch_end]
        Xp_batch[2::nb_versions] = Xp_shuf[batch_start:batch_end]
        Xo_batch[2::nb_versions] = index_gen(curr_batch_size, np.arange(nb_entities))

        feed_dict = {
            subject_inputs: Xs_batch, predicate_inputs: Xp_batch, object_inputs: Xo_batch,
            target_inputs: np.array([1.0, 0.0, 0.0] * curr_batch_size)
        }

        _, loss_value = session.run([training_step, loss], feed_dict=feed_dict)
        session.run(projection_step)

        # Average loss per positive triple in the batch
        loss_values += [loss_value / (Xp_batch.shape[0] / nb_versions)]
    
    if epoch % 10 == 0:
        logger.info('Epoch {0}\tLoss value: {1}'.format(epoch, stats(loss_values)))


INFO:__main__:Epoch 10	Loss value: 5.4755 ± 0.1380
INFO:__main__:Epoch 20	Loss value: 0.2603 ± 0.0061
INFO:__main__:Epoch 30	Loss value: 0.0730 ± 0.0038
INFO:__main__:Epoch 40	Loss value: 0.0492 ± 0.0039
INFO:__main__:Epoch 50	Loss value: 0.0430 ± 0.0032
INFO:__main__:Epoch 60	Loss value: 0.0363 ± 0.0037
INFO:__main__:Epoch 70	Loss value: 0.0326 ± 0.0031
INFO:__main__:Epoch 80	Loss value: 0.0317 ± 0.0043
INFO:__main__:Epoch 90	Loss value: 0.0275 ± 0.0033
INFO:__main__:Epoch 100	Loss value: 0.0259 ± 0.0034
INFO:__main__:Epoch 110	Loss value: 0.0246 ± 0.0043
INFO:__main__:Epoch 120	Loss value: 0.0243 ± 0.0044
INFO:__main__:Epoch 130	Loss value: 0.0214 ± 0.0025
INFO:__main__:Epoch 140	Loss value: 0.0233 ± 0.0037
INFO:__main__:Epoch 150	Loss value: 0.0218 ± 0.0032
INFO:__main__:Epoch 160	Loss value: 0.0204 ± 0.0027
INFO:__main__:Epoch 170	Loss value: 0.0197 ± 0.0030
INFO:__main__:Epoch 180	Loss value: 0.0201 ± 0.0037
INFO:__main__:Epoch 190	Loss value: 0.0200 ± 0.0036
INFO:__main__:Epoch 200	Loss value: 0.0210 ± 0.0027
INFO:__main__:Epoch 210	Loss value: 0.0206 ± 0.0034
INFO:__main__:Epoch 220	Loss value: 0.0195 ± 0.0036
INFO:__main__:Epoch 230	Loss value: 0.0212 ± 0.0028
INFO:__main__:Epoch 240	Loss value: 0.0188 ± 0.0037
INFO:__main__:Epoch 250	Loss value: 0.0214 ± 0.0032
INFO:__main__:Epoch 260	Loss value: 0.0199 ± 0.0033
INFO:__main__:Epoch 270	Loss value: 0.0172 ± 0.0031
INFO:__main__:Epoch 280	Loss value: 0.0166 ± 0.0020
INFO:__main__:Epoch 290	Loss value: 0.0177 ± 0.0035
INFO:__main__:Epoch 300	Loss value: 0.0197 ± 0.0032
INFO:__main__:Epoch 310	Loss value: 0.0178 ± 0.0031
INFO:__main__:Epoch 320	Loss value: 0.0168 ± 0.0038
INFO:__main__:Epoch 330	Loss value: 0.0161 ± 0.0037
INFO:__main__:Epoch 340	Loss value: 0.0154 ± 0.0036
INFO:__main__:Epoch 350	Loss value: 0.0169 ± 0.0032
INFO:__main__:Epoch 360	Loss value: 0.0162 ± 0.0027
INFO:__main__:Epoch 370	Loss value: 0.0156 ± 0.0029
INFO:__main__:Epoch 380	Loss value: 0.0167 ± 0.0034
INFO:__main__:Epoch 390	Loss value: 0.0156 ± 0.0033
INFO:__main__:Epoch 400	Loss value: 0.0173 ± 0.0031
INFO:__main__:Epoch 410	Loss value: 0.0148 ± 0.0031
INFO:__main__:Epoch 420	Loss value: 0.0155 ± 0.0031
INFO:__main__:Epoch 430	Loss value: 0.0165 ± 0.0029
INFO:__main__:Epoch 440	Loss value: 0.0159 ± 0.0030
INFO:__main__:Epoch 450	Loss value: 0.0163 ± 0.0023
INFO:__main__:Epoch 460	Loss value: 0.0166 ± 0.0038
INFO:__main__:Epoch 470	Loss value: 0.0144 ± 0.0022
INFO:__main__:Epoch 480	Loss value: 0.0160 ± 0.0040
INFO:__main__:Epoch 490	Loss value: 0.0151 ± 0.0036
INFO:__main__:Epoch 500	Loss value: 0.0139 ± 0.0027
INFO:__main__:Epoch 510	Loss value: 0.0160 ± 0.0019
INFO:__main__:Epoch 520	Loss value: 0.0144 ± 0.0025
INFO:__main__:Epoch 530	Loss value: 0.0146 ± 0.0021
INFO:__main__:Epoch 540	Loss value: 0.0139 ± 0.0022
INFO:__main__:Epoch 550	Loss value: 0.0155 ± 0.0031
INFO:__main__:Epoch 560	Loss value: 0.0154 ± 0.0030
INFO:__main__:Epoch 570	Loss value: 0.0160 ± 0.0028
INFO:__main__:Epoch 580	Loss value: 0.0160 ± 0.0033
INFO:__main__:Epoch 590	Loss value: 0.0151 ± 0.0025
INFO:__main__:Epoch 600	Loss value: 0.0130 ± 0.0036
INFO:__main__:Epoch 610	Loss value: 0.0158 ± 0.0026
INFO:__main__:Epoch 620	Loss value: 0.0134 ± 0.0025
INFO:__main__:Epoch 630	Loss value: 0.0152 ± 0.0029
INFO:__main__:Epoch 640	Loss value: 0.0128 ± 0.0033
INFO:__main__:Epoch 650	Loss value: 0.0145 ± 0.0026
INFO:__main__:Epoch 660	Loss value: 0.0138 ± 0.0027
INFO:__main__:Epoch 670	Loss value: 0.0141 ± 0.0022
INFO:__main__:Epoch 680	Loss value: 0.0144 ± 0.0021
INFO:__main__:Epoch 690	Loss value: 0.0134 ± 0.0017
INFO:__main__:Epoch 700	Loss value: 0.0127 ± 0.0027
INFO:__main__:Epoch 710	Loss value: 0.0148 ± 0.0027
INFO:__main__:Epoch 720	Loss value: 0.0140 ± 0.0022
INFO:__main__:Epoch 730	Loss value: 0.0145 ± 0.0027
INFO:__main__:Epoch 740	Loss value: 0.0135 ± 0.0031
INFO:__main__:Epoch 750	Loss value: 0.0127 ± 0.0034
INFO:__main__:Epoch 760	Loss value: 0.0161 ± 0.0049
INFO:__main__:Epoch 770	Loss value: 0.0125 ± 0.0035
INFO:__main__:Epoch 780	Loss value: 0.0130 ± 0.0024
INFO:__main__:Epoch 790	Loss value: 0.0150 ± 0.0027
INFO:__main__:Epoch 800	Loss value: 0.0144 ± 0.0035
INFO:__main__:Epoch 810	Loss value: 0.0144 ± 0.0024
INFO:__main__:Epoch 820	Loss value: 0.0133 ± 0.0028
INFO:__main__:Epoch 830	Loss value: 0.0148 ± 0.0015
INFO:__main__:Epoch 840	Loss value: 0.0147 ± 0.0033
INFO:__main__:Epoch 850	Loss value: 0.0132 ± 0.0024
INFO:__main__:Epoch 860	Loss value: 0.0130 ± 0.0024
INFO:__main__:Epoch 870	Loss value: 0.0135 ± 0.0027
INFO:__main__:Epoch 880	Loss value: 0.0147 ± 0.0013
INFO:__main__:Epoch 890	Loss value: 0.0162 ± 0.0038
INFO:__main__:Epoch 900	Loss value: 0.0129 ± 0.0037
INFO:__main__:Epoch 910	Loss value: 0.0131 ± 0.0022
INFO:__main__:Epoch 920	Loss value: 0.0134 ± 0.0017
INFO:__main__:Epoch 930	Loss value: 0.0127 ± 0.0021
INFO:__main__:Epoch 940	Loss value: 0.0127 ± 0.0032
INFO:__main__:Epoch 950	Loss value: 0.0141 ± 0.0023
INFO:__main__:Epoch 960	Loss value: 0.0123 ± 0.0033
INFO:__main__:Epoch 970	Loss value: 0.0133 ± 0.0026
INFO:__main__:Epoch 980	Loss value: 0.0121 ± 0.0024
INFO:__main__:Epoch 990	Loss value: 0.0152 ± 0.0016
INFO:__main__:Epoch 1000	Loss value: 0.0130 ± 0.0023

In [6]:
for eval_name, eval_triples in [('valid', valid_triples), ('test', test_triples)]:

    ranks_subj, ranks_obj = [], []
    filtered_ranks_subj, filtered_ranks_obj = [], []

    for _i, (s, p, o) in enumerate(eval_triples):
        s_idx, p_idx, o_idx = entity_to_idx[s], predicate_to_idx[p], entity_to_idx[o]

        Xs = np.full(shape=(nb_entities,), fill_value=s_idx, dtype=np.int32)
        Xp = np.full(shape=(nb_entities,), fill_value=p_idx, dtype=np.int32)
        Xo = np.full(shape=(nb_entities,), fill_value=o_idx, dtype=np.int32)

        feed_dict_corrupt_subj = {subject_inputs: np.arange(nb_entities), predicate_inputs: Xp, object_inputs: Xo}
        feed_dict_corrupt_obj = {subject_inputs: Xs, predicate_inputs: Xp, object_inputs: np.arange(nb_entities)}

        # Scores of (e, p, o) for every candidate entity e
        scores_subj = session.run(scores, feed_dict=feed_dict_corrupt_subj)

        # Scores of (s, p, e) for every candidate entity e
        scores_obj = session.run(scores, feed_dict=feed_dict_corrupt_obj)

        # Raw rank: 1 + number of candidates scoring strictly higher than the true entity
        ranks_subj += [1 + np.sum(scores_subj > scores_subj[s_idx])]
        ranks_obj += [1 + np.sum(scores_obj > scores_obj[o_idx])]

        filtered_scores_subj = scores_subj.copy()
        filtered_scores_obj = scores_obj.copy()

        # Filtered setting: mask every other true triple from train/valid/test,
        # so known facts are not counted as ranking errors
        rm_idx_s = [entity_to_idx[fs] for (fs, fp, fo) in all_triples if fs != s and fp == p and fo == o]
        rm_idx_o = [entity_to_idx[fo] for (fs, fp, fo) in all_triples if fs == s and fp == p and fo != o]

        filtered_scores_subj[rm_idx_s] = -np.inf
        filtered_scores_obj[rm_idx_o] = -np.inf

        filtered_ranks_subj += [1 + np.sum(filtered_scores_subj > filtered_scores_subj[s_idx])]
        filtered_ranks_obj += [1 + np.sum(filtered_scores_obj > filtered_scores_obj[o_idx])]

        if _i % 1000 == 0:
            logger.info('{}/{} ..'.format(_i, len(eval_triples)))
        
        
    ranks = ranks_subj + ranks_obj
    filtered_ranks = filtered_ranks_subj + filtered_ranks_obj

    for setting_name, setting_ranks in [('Raw', ranks), ('Filtered', filtered_ranks)]:
        mean_rank = np.mean(setting_ranks)
        logger.info('[{}] {} Mean Rank: {}'.format(eval_name, setting_name, mean_rank))
        for k in [1, 3, 5, 10]:
            hits_at_k = np.mean(np.asarray(setting_ranks) <= k) * 100
            logger.info('[{}] {} Hits@{}: {}'.format(eval_name, setting_name, k, hits_at_k))


INFO:__main__:0/5000 ..
INFO:__main__:1000/5000 ..
INFO:__main__:2000/5000 ..
INFO:__main__:3000/5000 ..
INFO:__main__:4000/5000 ..
INFO:__main__:[valid] Raw Mean Rank: 822.6784
INFO:__main__:[valid] Raw Hits@1: 41.349999999999994
INFO:__main__:[valid] Raw Hits@3: 66.85
INFO:__main__:[valid] Raw Hits@5: 75.14
INFO:__main__:[valid] Raw Hits@10: 81.97
INFO:__main__:[valid] Filtered Mean Rank: 812.4296
INFO:__main__:[valid] Filtered Hits@1: 66.9
INFO:__main__:[valid] Filtered Hits@3: 90.64999999999999
INFO:__main__:[valid] Filtered Hits@5: 93.0
INFO:__main__:[valid] Filtered Hits@10: 94.07
INFO:__main__:0/5000 ..
INFO:__main__:1000/5000 ..
INFO:__main__:2000/5000 ..
INFO:__main__:3000/5000 ..
INFO:__main__:4000/5000 ..
INFO:__main__:[test] Raw Mean Rank: 857.8714
INFO:__main__:[test] Raw Hits@1: 41.06
INFO:__main__:[test] Raw Hits@3: 67.44
INFO:__main__:[test] Raw Hits@5: 75.12
INFO:__main__:[test] Raw Hits@10: 82.43
INFO:__main__:[test] Filtered Mean Rank: 847.4885
INFO:__main__:[test] Filtered Hits@1: 66.09
INFO:__main__:[test] Filtered Hits@3: 90.67
INFO:__main__:[test] Filtered Hits@5: 93.10000000000001
INFO:__main__:[test] Filtered Hits@10: 94.19999999999999
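
The mean reciprocal rank (MRR), a common companion to Mean Rank and Hits@k that is less dominated by a few badly ranked triples, can be computed from the same ranks. A one-line sketch, assuming the filtered_ranks list from the evaluation loop is still in scope:

mrr = np.mean(1.0 / np.asarray(filtered_ranks))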

In [ ]: