In [1]:
import numpy as np
import tensorflow as tf

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f:
            # One whitespace-separated (subject, predicate, object) triple per line
            s, p, o = line.split()
            triples += [(s, p, o)]
    return triples


def unit_cube_projection(var_matrix):
    # Project the embeddings onto the unit hypercube [0, 1]^k by clipping
    projected = tf.minimum(1., tf.maximum(var_matrix, 0.))
    return tf.assign(var_matrix, projected)


def make_batches(size, batch_size):
    # Partition [0, size) into contiguous (start, end) ranges of at most batch_size
    nb_batch = int(np.ceil(size / float(batch_size)))
    return [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(nb_batch)]

class IndexGenerator:
    """Samples pseudo-random entity indices, used to corrupt triples during training."""

    def __init__(self):
        self.random_state = np.random.RandomState(0)

    def __call__(self, n_samples, candidate_indices):
        shuffled_indices = candidate_indices[self.random_state.permutation(len(candidate_indices))]
        rand_ints = shuffled_indices[np.arange(n_samples) % len(shuffled_indices)]
        return rand_ints
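
A quick sanity check of the two utilities above: make_batches partitions example indices into contiguous ranges, and IndexGenerator draws corruption indices from a candidate pool. A minimal usage sketch:

In [ ]:
assert make_batches(10, 3) == [(0, 3), (3, 6), (6, 9), (9, 10)]

# Draw 7 indices from the candidate pool {0, ..., 4}
index_gen_demo = IndexGenerator()
sampled = index_gen_demo(n_samples=7, candidate_indices=np.arange(5))
assert sampled.shape == (7,) and set(sampled) <= set(range(5))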

In [2]:
class DistMult:
    """DistMult: score(s, p, o) is the trilinear dot product <subject, predicate, object>."""

    def __init__(self, subject_embeddings=None, object_embeddings=None,
                 predicate_embeddings=None):
        self.subject_embeddings, self.object_embeddings = subject_embeddings, object_embeddings
        self.predicate_embeddings = predicate_embeddings

    def __call__(self):
        scores = tf.reduce_sum(self.subject_embeddings *
                               self.predicate_embeddings *
                               self.object_embeddings, axis=1)
        return scores

class ComplEx:
    """ComplEx: complex-valued embeddings, stored as concatenated [real; imaginary] halves."""

    def __init__(self, subject_embeddings=None, object_embeddings=None,
                 predicate_embeddings=None):
        self.subject_embeddings, self.object_embeddings = subject_embeddings, object_embeddings
        self.predicate_embeddings = predicate_embeddings

    def __call__(self):
        es_re, es_im = tf.split(value=self.subject_embeddings, num_or_size_splits=2, axis=1)
        eo_re, eo_im = tf.split(value=self.object_embeddings, num_or_size_splits=2, axis=1)
        ew_re, ew_im = tf.split(value=self.predicate_embeddings, num_or_size_splits=2, axis=1)

        def dot3(arg1, rel, arg2):
            return tf.reduce_sum(arg1 * rel * arg2, axis=1)

        # Re(<s, w, conj(o)>) expanded into its four real-valued trilinear terms
        scores = dot3(es_re, ew_re, eo_re) + dot3(es_re, ew_im, eo_im) + \
            dot3(es_im, ew_re, eo_im) - dot3(es_im, ew_im, eo_re)
        return scores
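
The four trilinear terms in ComplEx.__call__ are exactly the expansion of Re(<s, w, conj(o)>), the real part of the Hermitian three-way product, into real and imaginary components. A small NumPy check of that identity (a standalone sketch, independent of the TensorFlow graph):

In [ ]:
rng = np.random.RandomState(0)
k = 4  # complex embedding dimension; the stored vectors have 2 * k real entries
re_s, im_s, re_w, im_w, re_o, im_o = rng.randn(6, k)

s, w, o = re_s + 1j * im_s, re_w + 1j * im_w, re_o + 1j * im_o
complex_score = np.sum(s * w * np.conj(o)).real

# The same score from the four real-valued terms used in ComplEx.__call__
real_score = (np.sum(re_s * re_w * re_o) + np.sum(re_s * im_w * im_o)
              + np.sum(im_s * re_w * im_o) - np.sum(im_s * im_w * re_o))

assert np.isclose(complex_score, real_score)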

In [3]:
entity_embedding_size = 150
predicate_embedding_size = 150

seed = 0
margin = 5

nb_epochs = 100

nb_discriminator_epochs = 1
nb_adversary_epochs = 10

nb_batches = 10

violation_loss_weight = 0.1
adversary_batch_size = 5

np.random.seed(seed)
random_state = np.random.RandomState(seed)
tf.set_random_seed(seed)

dataset_name = 'fb122'

train_triples = read_triples('{}/{}.train.tsv'.format(dataset_name, dataset_name))
valid_triples = read_triples('{}/{}.valid.tsv'.format(dataset_name, dataset_name))
test_triples = read_triples('{}/{}.test.tsv'.format(dataset_name, dataset_name))

from parse import parse_clause
with open('{}/{}-clauses.pl'.format(dataset_name, dataset_name), 'rt') as f:
    lines = f.readlines()

clauses = [parse_clause(line.strip()) for line in lines]
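
The Adversary class defined below relies only on a small part of the parsed clause structure: clause.head, clause.body, atom.predicate.name and atom.arguments[i].name. A hypothetical mock of a clause such as q(X, Y) :- p(Y, X), purely to illustrate that structure (the real objects come from parse_clause):

In [ ]:
from collections import namedtuple

# Illustrative stand-ins mirroring only the attributes the Adversary uses
MockPredicate = namedtuple('MockPredicate', ['name'])
MockVariable = namedtuple('MockVariable', ['name'])
MockAtom = namedtuple('MockAtom', ['predicate', 'arguments'])
MockClause = namedtuple('MockClause', ['head', 'body'])

# q(X, Y) :- p(Y, X)
mock_clause = MockClause(head=MockAtom(MockPredicate('q'), [MockVariable('X'), MockVariable('Y')]),
                         body=[MockAtom(MockPredicate('p'), [MockVariable('Y'), MockVariable('X')])])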

In [4]:
all_triples = train_triples + valid_triples + test_triples

entity_set = {s for (s, p, o) in all_triples} | {o for (s, p, o) in all_triples}
predicate_set = {p for (s, p, o) in all_triples}

nb_entities, nb_predicates = len(entity_set), len(predicate_set)
nb_examples = len(train_triples)

entity_to_idx = {entity: idx for idx, entity in enumerate(sorted(entity_set))}
predicate_to_idx = {predicate: idx for idx, predicate in enumerate(sorted(predicate_set))}

entity_embedding_layer = tf.get_variable('entities', shape=[nb_entities, entity_embedding_size],
                                         initializer=tf.contrib.layers.xavier_initializer())

predicate_embedding_layer = tf.get_variable('predicates', shape=[nb_predicates, predicate_embedding_size],
                                            initializer=tf.contrib.layers.xavier_initializer())

subject_inputs = tf.placeholder(tf.int32, shape=[None])
predicate_inputs = tf.placeholder(tf.int32, shape=[None])
object_inputs = tf.placeholder(tf.int32, shape=[None])

target_inputs = tf.placeholder(tf.float32, shape=[None])

subject_embeddings = tf.nn.embedding_lookup(entity_embedding_layer, subject_inputs)
predicate_embeddings = tf.nn.embedding_lookup(predicate_embedding_layer, predicate_inputs)
object_embeddings = tf.nn.embedding_lookup(entity_embedding_layer, object_inputs)

model_parameters = {
    'subject_embeddings': subject_embeddings,
    'predicate_embeddings': predicate_embeddings,
    'object_embeddings': object_embeddings
}

model_class = ComplEx
model = model_class(**model_parameters)

scores = model()
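
Before training, the scoring graph can be smoke-tested on a couple of (still untrained) triples with a throwaway session; a minimal sketch:

In [ ]:
with tf.Session() as smoke_session:
    smoke_session.run(tf.global_variables_initializer())
    (s0, p0, o0), (s1, p1, o1) = train_triples[0], train_triples[1]
    toy_feed = {
        subject_inputs: [entity_to_idx[s0], entity_to_idx[s1]],
        predicate_inputs: [predicate_to_idx[p0], predicate_to_idx[p1]],
        object_inputs: [entity_to_idx[o0], entity_to_idx[o1]]
    }
    print(smoke_session.run(scores, feed_dict=toy_feed))  # two floats, one per triple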

In [5]:
class Adversary:
    """
    Utility class for, given a set of clauses, computing the symbolic violation loss.
    """

    def __init__(self, clauses, predicate_to_index,
                 entity_embedding_layer, predicate_embedding_layer,
                 model_class, model_parameters, loss_margin=0.0, batch_size=1):

        self.clauses, self.predicate_to_index = clauses, predicate_to_index
        self.entity_embedding_layer = entity_embedding_layer
        self.predicate_embedding_layer = predicate_embedding_layer

        self.entity_embedding_size = self.entity_embedding_layer.get_shape()[-1].value

        self.model_class, self.model_parameters = model_class, model_parameters
        self.batch_size = batch_size

        def _violation_losses(body_scores, head_scores, margin):
            # A clause is violated when its body scores higher than its head;
            # keep the worst (maximum) violation across the adversarial batch
            _losses = tf.nn.relu(margin - head_scores + body_scores)
            return tf.reduce_max(_losses)

        self.loss_function = lambda body_scores, head_scores:\
            _violation_losses(body_scores, head_scores, margin=loss_margin)

        # Symbolic expression computing the continuous violation loss
        self.loss = tf.constant(.0)

        # Trainable parameters of the adversarial model
        self.parameters = []

        # Mapping {clause: v2l}, where v2l maps each variable name in the clause to its embedding layer
        self.clause_to_variable_name_to_layer = dict()
        self.clause_to_loss = dict()

        for clause_idx, clause in enumerate(clauses):
            clause_loss, clause_parameters, variable_name_to_layer =\
                self._parse_clause('clause_{}'.format(clause_idx), clause)

            self.clause_to_variable_name_to_layer[clause] = variable_name_to_layer
            self.clause_to_loss[clause] = clause_loss

            self.loss += clause_loss
            self.parameters += clause_parameters

    def _parse_atom(self, atom, variable_name_to_layer):
        """
        Given an atom in the form p(X, Y), where X and Y are associated to two distinct [1, k] embedding layers,
        return the symbolic score of the atom.
        """
        predicate_idx = self.predicate_to_index[atom.predicate.name]
        
        # [batch_size x 1 x embedding_size] tensor
        predicate_embeddings = tf.nn.embedding_lookup(self.predicate_embedding_layer, [predicate_idx] * self.batch_size)
        arg1_name, arg2_name = atom.arguments[0].name, atom.arguments[1].name

        # [batch_size x embedding_size] adversarial variables associated with the atom's arguments
        subject_embeddings = variable_name_to_layer[arg1_name]
        object_embeddings = variable_name_to_layer[arg2_name]

        # Copy the dict so the shared model_parameters are not mutated across atoms
        model_parameters = dict(self.model_parameters)
        model_parameters['subject_embeddings'] = subject_embeddings
        model_parameters['object_embeddings'] = object_embeddings
        model_parameters['predicate_embeddings'] = predicate_embeddings

        scoring_model = self.model_class(**model_parameters)
        atom_score = scoring_model()

        return atom_score

    def _parse_conjunction(self, atoms, variable_name_to_layer):
        """
        Given a conjunction of atoms in the form p(X0, X1), q(X2, X3), r(X4, X5), return its symbolic score.
        """
        conjunction_score = None
        for atom in atoms:
            atom_score = self._parse_atom(atom, variable_name_to_layer=variable_name_to_layer)
            # Score a conjunction with the Gödel t-norm: the minimum of the atom scores
            conjunction_score = atom_score if conjunction_score is None else tf.minimum(conjunction_score, atom_score)
        return conjunction_score

    def _parse_clause(self, name, clause):
        """
        Given a clause in the form p(X0, X1) :- q(X2, X3), r(X4, X5), return its symbolic score.
        """
        head, body = clause.head, clause.body

        # Enumerate all variables
        variable_names = {argument.name for argument in head.arguments}
        for body_atom in body:
            variable_names |= {argument.name for argument in body_atom.arguments}

        # Instantiate a new layer for each variable
        variable_name_to_layer = dict()
        for variable_name in sorted(variable_names):
            # [batch_size, embedding_size] variable
            variable_layer = tf.get_variable('{}_{}_violator'.format(name, variable_name),
                                             shape=[self.batch_size, self.entity_embedding_size],
                                             initializer=tf.contrib.layers.xavier_initializer())
            variable_name_to_layer[variable_name] = variable_layer

        head_score = self._parse_atom(head, variable_name_to_layer=variable_name_to_layer)
        body_score = self._parse_conjunction(body, variable_name_to_layer=variable_name_to_layer)

        parameters = [variable_name_to_layer[variable_name] for variable_name in sorted(variable_names)]
        loss = self.loss_function(body_score, head_score)
        return loss, parameters, variable_name_to_layer
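
For a single clause, the adversary's objective is the hinge max_i ReLU(margin - head_score_i + body_score_i) over the batch of adversarially chosen variable embeddings: it is positive exactly when some instantiation scores the body above the head. A NumPy replica of _violation_losses with loss_margin = 0:

In [ ]:
body_scores = np.array([0.3, 1.2, -0.5])
head_scores = np.array([0.8, 0.4, 0.1])
losses = np.maximum(0.0, 0.0 - head_scores + body_scores)  # [0. , 0.8, 0. ]
assert np.isclose(np.max(losses), 0.8)  # only the second instantiation violates the clause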

In [ ]:
adversary = Adversary(clauses=clauses, predicate_to_index=predicate_to_idx,
                      entity_embedding_layer=entity_embedding_layer,
                      predicate_embedding_layer=predicate_embedding_layer,
                      model_class=model_class, model_parameters=model_parameters,
                      batch_size=adversary_batch_size)

adversary_init_op = tf.variables_initializer(var_list=adversary.parameters, name='init_adversary')
violation_loss = adversary.loss

ADVERSARIAL_OPTIMIZER_SCOPE_NAME = 'adversary/optimizer'
with tf.variable_scope(ADVERSARIAL_OPTIMIZER_SCOPE_NAME):
    adversarial_optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    # Maximise the violation loss by minimising its negation w.r.t. the adversarial variables
    adversarial_training_step = adversarial_optimizer.minimize(- violation_loss, var_list=adversary.parameters)

adversary_optimizer_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=ADVERSARIAL_OPTIMIZER_SCOPE_NAME)
adversary_optimizer_vars_init_op = tf.variables_initializer(adversary_optimizer_vars)

adversary_projections = [unit_cube_projection(emb) for emb in adversary.parameters]


# Hinge loss on the scores; targets in {0, 1} are mapped to signs in {-1, +1}
hinge_losses = tf.nn.relu(margin - scores * (2 * target_inputs - 1))

loss = tf.reduce_sum(hinge_losses) + violation_loss_weight * violation_loss

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
training_step = optimizer.minimize(loss, var_list=[entity_embedding_layer, predicate_embedding_layer])

projection_step = unit_cube_projection(entity_embedding_layer)


import math
batch_size = math.ceil(nb_examples / nb_batches)
batches = make_batches(nb_examples, batch_size)

nb_versions = 3  # one positive triple plus two corrupted versions (subject- and object-corrupted)

Xs = np.array([entity_to_idx[s] for (s, p, o) in train_triples], dtype=np.int32)
Xp = np.array([predicate_to_idx[p] for (s, p, o) in train_triples], dtype=np.int32)
Xo = np.array([entity_to_idx[o] for (s, p, o) in train_triples], dtype=np.int32)

index_gen = IndexGenerator()

init_op = tf.global_variables_initializer()
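
Each training batch below interleaves three versions of every positive triple: the original at offset 0, a subject-corrupted copy at offset 1 and an object-corrupted copy at offset 2, matching the [1.0, 0.0, 0.0] target pattern fed to target_inputs. A toy illustration of the interleaving for two positives:

In [ ]:
demo = np.zeros(2 * nb_versions, dtype=np.int32)
demo[0::nb_versions] = [10, 20]  # positive subject indices
demo[1::nb_versions] = [77, 88]  # corrupted subject indices
print(demo)  # [10 77  0 20 88  0]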

In [ ]:
def stats(values):
    # Format a list of values as "mean ± standard deviation"
    return '{0:.4f} ± {1:.4f}'.format(np.mean(values), np.std(values))

session = tf.Session()
session.run(init_op)

for epoch in range(1, nb_epochs + 1):


    for discriminator_epoch in range(1, nb_discriminator_epochs + 1):
        order = random_state.permutation(nb_examples)
        Xs_shuf, Xp_shuf, Xo_shuf = Xs[order], Xp[order], Xo[order]

        loss_values = []

        for batch_no, (batch_start, batch_end) in enumerate(batches):
            curr_batch_size = batch_end - batch_start

            Xs_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xs_shuf.dtype)
            Xp_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xp_shuf.dtype)
            Xo_batch = np.zeros(curr_batch_size * nb_versions, dtype=Xo_shuf.dtype)

            Xs_batch[0::nb_versions] = Xs_shuf[batch_start:batch_end]
            Xp_batch[0::nb_versions] = Xp_shuf[batch_start:batch_end]
            Xo_batch[0::nb_versions] = Xo_shuf[batch_start:batch_end]

            # Version 1: corrupt the subject, keep the predicate and object
            Xs_batch[1::nb_versions] = index_gen(curr_batch_size, np.arange(nb_entities))
            Xp_batch[1::nb_versions] = Xp_shuf[batch_start:batch_end]
            Xo_batch[1::nb_versions] = Xo_shuf[batch_start:batch_end]

            # Version 2: corrupt the object, keep the subject and predicate
            Xs_batch[2::nb_versions] = Xs_shuf[batch_start:batch_end]
            Xp_batch[2::nb_versions] = Xp_shuf[batch_start:batch_end]
            Xo_batch[2::nb_versions] = index_gen(curr_batch_size, np.arange(nb_entities))

            feed_dict = {
                subject_inputs: Xs_batch, predicate_inputs: Xp_batch, object_inputs: Xo_batch,
                target_inputs: np.array([1.0, 0.0, 0.0] * curr_batch_size)
            }

            _, loss_value = session.run([training_step, loss], feed_dict=feed_dict)
            session.run(projection_step)

            loss_values += [loss_value / (Xp_batch.shape[0] / nb_versions)]

        logger.info('Epoch {0}/{1}\tLoss value: {2}'.format(epoch, discriminator_epoch, stats(loss_values)))


    session.run([adversary_init_op, adversary_optimizer_vars_init_op])
    entity_indices = np.array(sorted(entity_to_idx.values()))

    def ground_init_op(adversarial_embeddings):
        # Re-initialise the adversarial variables with the embeddings of randomly drawn entities.
        # Note: this builds new graph ops at every epoch, so the graph grows over time.
        rnd_entity_indices = entity_indices[
            random_state.randint(low=0,
                                 high=len(entity_indices),
                                 size=adversary_batch_size)]
        entity_embeddings = tf.nn.embedding_lookup(entity_embedding_layer, rnd_entity_indices)
        return adversarial_embeddings.assign(entity_embeddings)

    assignment_ops = [ground_init_op(emb) for emb in adversary.parameters]
    session.run(assignment_ops)

    for adversary_epoch in range(1, nb_adversary_epochs + 1):
        _, violation_loss_value = session.run([adversarial_training_step, violation_loss])
        logger.info('Epoch {0}/{1}\tLoss value: {2}'.format(epoch, adversary_epoch, violation_loss_value))

        session.run(adversary_projections)


INFO:__main__:Epoch 1/1	Loss value: 14.9996 ± 0.0005
INFO:__main__:Epoch 1/1	Loss value: 0.037005141377449036
INFO:__main__:Epoch 1/2	Loss value: 0.05832251161336899
INFO:__main__:Epoch 1/3	Loss value: 0.08015210926532745
INFO:__main__:Epoch 1/4	Loss value: 0.10290764272212982
INFO:__main__:Epoch 1/5	Loss value: 0.12751871347427368
INFO:__main__:Epoch 1/6	Loss value: 0.15434563159942627
INFO:__main__:Epoch 1/7	Loss value: 0.18477438390254974
INFO:__main__:Epoch 1/8	Loss value: 0.2192409336566925
INFO:__main__:Epoch 1/9	Loss value: 0.25854042172431946
INFO:__main__:Epoch 1/10	Loss value: 0.3039083778858185
INFO:__main__:Epoch 2/1	Loss value: 14.6677 ± 0.4112
INFO:__main__:Epoch 2/1	Loss value: 27.937274932861328
INFO:__main__:Epoch 2/2	Loss value: 47.388267517089844
INFO:__main__:Epoch 2/3	Loss value: 69.58460235595703
INFO:__main__:Epoch 2/4	Loss value: 95.48681640625
INFO:__main__:Epoch 2/5	Loss value: 125.84796142578125
INFO:__main__:Epoch 2/6	Loss value: 160.33551025390625
INFO:__main__:Epoch 2/7	Loss value: 199.41590881347656
INFO:__main__:Epoch 2/8	Loss value: 242.7416229248047
INFO:__main__:Epoch 2/9	Loss value: 290.0210876464844
INFO:__main__:Epoch 2/10	Loss value: 341.21929931640625
INFO:__main__:Epoch 3/1	Loss value: 10.5054 ± 1.8336
INFO:__main__:Epoch 3/1	Loss value: 67.00949096679688
INFO:__main__:Epoch 3/2	Loss value: 138.52745056152344
INFO:__main__:Epoch 3/3	Loss value: 212.14288330078125
INFO:__main__:Epoch 3/4	Loss value: 289.3619384765625
INFO:__main__:Epoch 3/5	Loss value: 374.4559631347656
INFO:__main__:Epoch 3/6	Loss value: 464.1065368652344
INFO:__main__:Epoch 3/7	Loss value: 558.7449951171875
INFO:__main__:Epoch 3/8	Loss value: 658.7227783203125
INFO:__main__:Epoch 3/9	Loss value: 761.2094116210938
INFO:__main__:Epoch 3/10	Loss value: 865.8599243164062
INFO:__main__:Epoch 4/1	Loss value: 6.5857 ± 0.3513
INFO:__main__:Epoch 4/1	Loss value: 54.379486083984375
INFO:__main__:Epoch 4/2	Loss value: 147.379150390625
INFO:__main__:Epoch 4/3	Loss value: 244.63243103027344
INFO:__main__:Epoch 4/4	Loss value: 336.07159423828125
INFO:__main__:Epoch 4/5	Loss value: 436.36920166015625
INFO:__main__:Epoch 4/6	Loss value: 540.1325073242188
INFO:__main__:Epoch 4/7	Loss value: 646.1339111328125
INFO:__main__:Epoch 4/8	Loss value: 754.5147094726562
INFO:__main__:Epoch 4/9	Loss value: 864.042236328125
INFO:__main__:Epoch 4/10	Loss value: 971.8656616210938
INFO:__main__:Epoch 5/1	Loss value: 5.3929 ± 0.1513
INFO:__main__:Epoch 5/1	Loss value: 60.66802978515625
INFO:__main__:Epoch 5/2	Loss value: 169.662841796875
INFO:__main__:Epoch 5/3	Loss value: 279.13446044921875
INFO:__main__:Epoch 5/4	Loss value: 387.0870361328125
INFO:__main__:Epoch 5/5	Loss value: 504.5575256347656
INFO:__main__:Epoch 5/6	Loss value: 622.1484375
INFO:__main__:Epoch 5/7	Loss value: 745.43359375
INFO:__main__:Epoch 5/8	Loss value: 866.3714599609375
INFO:__main__:Epoch 5/9	Loss value: 987.5112915039062
INFO:__main__:Epoch 5/10	Loss value: 1104.11474609375
INFO:__main__:Epoch 6/1	Loss value: 4.5742 ± 0.0966
INFO:__main__:Epoch 6/1	Loss value: 71.15213012695312
INFO:__main__:Epoch 6/2	Loss value: 188.0497589111328
INFO:__main__:Epoch 6/3	Loss value: 307.7999572753906
INFO:__main__:Epoch 6/4	Loss value: 421.56268310546875
INFO:__main__:Epoch 6/5	Loss value: 538.3877563476562
INFO:__main__:Epoch 6/6	Loss value: 658.9730834960938
INFO:__main__:Epoch 6/7	Loss value: 780.6898193359375
INFO:__main__:Epoch 6/8	Loss value: 903.0748291015625
INFO:__main__:Epoch 6/9	Loss value: 1023.8822021484375
INFO:__main__:Epoch 6/10	Loss value: 1139.8067626953125
INFO:__main__:Epoch 7/1	Loss value: 3.9783 ± 0.0802
INFO:__main__:Epoch 7/1	Loss value: 108.32939147949219
INFO:__main__:Epoch 7/2	Loss value: 259.8048095703125
INFO:__main__:Epoch 7/3	Loss value: 397.495361328125
INFO:__main__:Epoch 7/4	Loss value: 530.4791870117188
INFO:__main__:Epoch 7/5	Loss value: 662.435546875
INFO:__main__:Epoch 7/6	Loss value: 797.0531616210938
INFO:__main__:Epoch 7/7	Loss value: 930.1895751953125
INFO:__main__:Epoch 7/8	Loss value: 1056.8377685546875
INFO:__main__:Epoch 7/9	Loss value: 1186.31103515625
INFO:__main__:Epoch 7/10	Loss value: 1302.7999267578125
INFO:__main__:Epoch 8/1	Loss value: 3.5289 ± 0.0313
INFO:__main__:Epoch 8/1	Loss value: 86.37255859375
INFO:__main__:Epoch 8/2	Loss value: 245.46218872070312
INFO:__main__:Epoch 8/3	Loss value: 376.4599609375
INFO:__main__:Epoch 8/4	Loss value: 517.5479125976562
INFO:__main__:Epoch 8/5	Loss value: 639.9364624023438
INFO:__main__:Epoch 8/6	Loss value: 774.43798828125
INFO:__main__:Epoch 8/7	Loss value: 899.12744140625
INFO:__main__:Epoch 8/8	Loss value: 1023.1798706054688
INFO:__main__:Epoch 8/9	Loss value: 1141.3125
INFO:__main__:Epoch 8/10	Loss value: 1254.4583740234375
INFO:__main__:Epoch 9/1	Loss value: 3.1341 ± 0.0297
INFO:__main__:Epoch 9/1	Loss value: 78.77142333984375
INFO:__main__:Epoch 9/2	Loss value: 216.1274871826172
INFO:__main__:Epoch 9/3	Loss value: 350.9869384765625
INFO:__main__:Epoch 9/4	Loss value: 471.5009460449219
INFO:__main__:Epoch 9/5	Loss value: 595.7521362304688
INFO:__main__:Epoch 9/6	Loss value: 713.3211059570312
INFO:__main__:Epoch 9/7	Loss value: 836.366943359375
INFO:__main__:Epoch 9/8	Loss value: 949.6336059570312
INFO:__main__:Epoch 9/9	Loss value: 1064.1300048828125
INFO:__main__:Epoch 9/10	Loss value: 1170.78466796875
INFO:__main__:Epoch 10/1	Loss value: 2.8532 ± 0.0499
INFO:__main__:Epoch 10/1	Loss value: 95.59754943847656
INFO:__main__:Epoch 10/2	Loss value: 256.1995544433594
INFO:__main__:Epoch 10/3	Loss value: 415.345458984375
INFO:__main__:Epoch 10/4	Loss value: 557.3989868164062
INFO:__main__:Epoch 10/5	Loss value: 700.972412109375
INFO:__main__:Epoch 10/6	Loss value: 838.5853881835938
INFO:__main__:Epoch 10/7	Loss value: 982.8052368164062
INFO:__main__:Epoch 10/8	Loss value: 1111.483154296875
INFO:__main__:Epoch 10/9	Loss value: 1246.389404296875
INFO:__main__:Epoch 10/10	Loss value: 1363.168701171875
INFO:__main__:Epoch 11/1	Loss value: 2.5844 ± 0.0384
INFO:__main__:Epoch 11/1	Loss value: 90.58894348144531
INFO:__main__:Epoch 11/2	Loss value: 234.47836303710938
INFO:__main__:Epoch 11/3	Loss value: 364.4402160644531
INFO:__main__:Epoch 11/4	Loss value: 488.31915283203125
INFO:__main__:Epoch 11/5	Loss value: 608.6617431640625
INFO:__main__:Epoch 11/6	Loss value: 731.4077758789062
INFO:__main__:Epoch 11/7	Loss value: 844.6109619140625
INFO:__main__:Epoch 11/8	Loss value: 959.2591552734375
INFO:__main__:Epoch 11/9	Loss value: 1067.35546875
INFO:__main__:Epoch 11/10	Loss value: 1168.385986328125
INFO:__main__:Epoch 12/1	Loss value: 2.3979 ± 0.0434
INFO:__main__:Epoch 12/1	Loss value: 79.91600036621094
INFO:__main__:Epoch 12/2	Loss value: 230.9688720703125
INFO:__main__:Epoch 12/3	Loss value: 385.2007751464844
INFO:__main__:Epoch 12/4	Loss value: 507.870849609375
INFO:__main__:Epoch 12/5	Loss value: 643.5989990234375
INFO:__main__:Epoch 12/6	Loss value: 768.223388671875
INFO:__main__:Epoch 12/7	Loss value: 894.1620483398438
INFO:__main__:Epoch 12/8	Loss value: 1019.3241577148438
INFO:__main__:Epoch 12/9	Loss value: 1133.106689453125
INFO:__main__:Epoch 12/10	Loss value: 1245.9158935546875
INFO:__main__:Epoch 13/1	Loss value: 2.2047 ± 0.0246
INFO:__main__:Epoch 13/1	Loss value: 94.26434326171875
INFO:__main__:Epoch 13/2	Loss value: 258.0623474121094
INFO:__main__:Epoch 13/3	Loss value: 395.2290954589844
INFO:__main__:Epoch 13/4	Loss value: 526.0809326171875
INFO:__main__:Epoch 13/5	Loss value: 646.218994140625
INFO:__main__:Epoch 13/6	Loss value: 766.4026489257812
INFO:__main__:Epoch 13/7	Loss value: 877.7215576171875
INFO:__main__:Epoch 13/8	Loss value: 989.1353759765625
INFO:__main__:Epoch 13/9	Loss value: 1094.05029296875
INFO:__main__:Epoch 13/10	Loss value: 1191.3203125
INFO:__main__:Epoch 14/1	Loss value: 2.0813 ± 0.0380
INFO:__main__:Epoch 14/1	Loss value: 90.2054672241211
INFO:__main__:Epoch 14/2	Loss value: 222.7648468017578
INFO:__main__:Epoch 14/3	Loss value: 355.99237060546875
INFO:__main__:Epoch 14/4	Loss value: 470.3455505371094
INFO:__main__:Epoch 14/5	Loss value: 572.9320068359375
INFO:__main__:Epoch 14/6	Loss value: 682.7808227539062
INFO:__main__:Epoch 14/7	Loss value: 788.0059814453125
INFO:__main__:Epoch 14/8	Loss value: 885.8945922851562
INFO:__main__:Epoch 14/9	Loss value: 980.7346801757812
INFO:__main__:Epoch 14/10	Loss value: 1075.2105712890625
INFO:__main__:Epoch 15/1	Loss value: 1.9502 ± 0.0283
INFO:__main__:Epoch 15/1	Loss value: 130.24859619140625
INFO:__main__:Epoch 15/2	Loss value: 319.31378173828125
INFO:__main__:Epoch 15/3	Loss value: 509.33135986328125
INFO:__main__:Epoch 15/4	Loss value: 665.1360473632812
INFO:__main__:Epoch 15/5	Loss value: 823.8536376953125
INFO:__main__:Epoch 15/6	Loss value: 970.2908325195312
INFO:__main__:Epoch 15/7	Loss value: 1117.1494140625
INFO:__main__:Epoch 15/8	Loss value: 1252.92529296875
INFO:__main__:Epoch 15/9	Loss value: 1378.828125
INFO:__main__:Epoch 15/10	Loss value: 1501.7156982421875
INFO:__main__:Epoch 16/1	Loss value: 1.8700 ± 0.0436
INFO:__main__:Epoch 16/1	Loss value: 99.4839096069336
INFO:__main__:Epoch 16/2	Loss value: 286.62646484375
INFO:__main__:Epoch 16/3	Loss value: 464.12646484375
INFO:__main__:Epoch 16/4	Loss value: 617.6766967773438
INFO:__main__:Epoch 16/5	Loss value: 768.6605224609375
INFO:__main__:Epoch 16/6	Loss value: 906.9293823242188
INFO:__main__:Epoch 16/7	Loss value: 1050.640380859375
INFO:__main__:Epoch 16/8	Loss value: 1181.0064697265625
INFO:__main__:Epoch 16/9	Loss value: 1306.41162109375
INFO:__main__:Epoch 16/10	Loss value: 1426.395751953125
INFO:__main__:Epoch 17/1	Loss value: 1.7621 ± 0.0157
INFO:__main__:Epoch 17/1	Loss value: 92.185302734375
INFO:__main__:Epoch 17/2	Loss value: 284.5326232910156
INFO:__main__:Epoch 17/3	Loss value: 465.41357421875
INFO:__main__:Epoch 17/4	Loss value: 621.428955078125
INFO:__main__:Epoch 17/5	Loss value: 773.66845703125
INFO:__main__:Epoch 17/6	Loss value: 921.8056030273438
INFO:__main__:Epoch 17/7	Loss value: 1065.903076171875
INFO:__main__:Epoch 17/8	Loss value: 1199.7479248046875
INFO:__main__:Epoch 17/9	Loss value: 1330.0518798828125
INFO:__main__:Epoch 17/10	Loss value: 1454.81201171875
INFO:__main__:Epoch 18/1	Loss value: 1.6942 ± 0.0346
INFO:__main__:Epoch 18/1	Loss value: 111.75216674804688
INFO:__main__:Epoch 18/2	Loss value: 289.0792236328125
INFO:__main__:Epoch 18/3	Loss value: 467.99346923828125
INFO:__main__:Epoch 18/4	Loss value: 613.46142578125
INFO:__main__:Epoch 18/5	Loss value: 753.604736328125
INFO:__main__:Epoch 18/6	Loss value: 894.119384765625
INFO:__main__:Epoch 18/7	Loss value: 1022.2001953125
INFO:__main__:Epoch 18/8	Loss value: 1147.1729736328125
INFO:__main__:Epoch 18/9	Loss value: 1266.08642578125
INFO:__main__:Epoch 18/10	Loss value: 1368.8175048828125

In [ ]:
for eval_name, eval_triples in [('valid', valid_triples), ('test', test_triples)]:

    ranks_subj, ranks_obj = [], []
    filtered_ranks_subj, filtered_ranks_obj = [], []

    for _i, (s, p, o) in enumerate(eval_triples):
        s_idx, p_idx, o_idx = entity_to_idx[s], predicate_to_idx[p], entity_to_idx[o]

        # Use fresh names so the training arrays Xs, Xp, Xo are not clobbered
        Xs_e = np.full(shape=(nb_entities,), fill_value=s_idx, dtype=np.int32)
        Xp_e = np.full(shape=(nb_entities,), fill_value=p_idx, dtype=np.int32)
        Xo_e = np.full(shape=(nb_entities,), fill_value=o_idx, dtype=np.int32)

        feed_dict_corrupt_subj = {subject_inputs: np.arange(nb_entities), predicate_inputs: Xp_e, object_inputs: Xo_e}
        feed_dict_corrupt_obj = {subject_inputs: Xs_e, predicate_inputs: Xp_e, object_inputs: np.arange(nb_entities)}

        # scores of (0, p, o), (1, p, o), ..., (N-1, p, o): every candidate subject
        scores_subj = session.run(scores, feed_dict=feed_dict_corrupt_subj)

        # scores of (s, p, 0), (s, p, 1), ..., (s, p, N-1): every candidate object
        scores_obj = session.run(scores, feed_dict=feed_dict_corrupt_obj)

        # Raw rank: 1 + number of candidates scoring strictly higher than the true entity
        ranks_subj += [1 + np.sum(scores_subj > scores_subj[s_idx])]
        ranks_obj += [1 + np.sum(scores_obj > scores_obj[o_idx])]

        filtered_scores_subj = scores_subj.copy()
        filtered_scores_obj = scores_obj.copy()

        # Filtered setting: mask every other entity that also forms a true triple in the dataset
        rm_idx_s = [entity_to_idx[fs] for (fs, fp, fo) in all_triples if fs != s and fp == p and fo == o]
        rm_idx_o = [entity_to_idx[fo] for (fs, fp, fo) in all_triples if fs == s and fp == p and fo != o]

        filtered_scores_subj[rm_idx_s] = - np.inf
        filtered_scores_obj[rm_idx_o] = - np.inf

        filtered_ranks_subj += [1 + np.sum(filtered_scores_subj > filtered_scores_subj[s_idx])]
        filtered_ranks_obj += [1 + np.sum(filtered_scores_obj > filtered_scores_obj[o_idx])]

        if _i % 1000 == 0:
            logger.info('{}/{} ..'.format(_i, len(eval_triples)))
        
        
    ranks = ranks_subj + ranks_obj
    filtered_ranks = filtered_ranks_subj + filtered_ranks_obj

    for setting_name, setting_ranks in [('Raw', ranks), ('Filtered', filtered_ranks)]:
        mean_rank = np.mean(setting_ranks)
        logger.info('[{}] {} Mean Rank: {}'.format(eval_name, setting_name, mean_rank))
        for k in [1, 3, 5, 10]:
            hits_at_k = np.mean(np.asarray(setting_ranks) <= k) * 100
            logger.info('[{}] {} Hits@{}: {}'.format(eval_name, setting_name, k, hits_at_k))
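
Mean Rank is the average position of the true entity among all candidates (lower is better), and Hits@k is the percentage of test cases where the true entity is ranked within the top k. A worked toy example:

In [ ]:
toy_ranks = np.array([1, 3, 12])
print(np.mean(toy_ranks))              # Mean Rank: 5.333...
print(np.mean(toy_ranks <= 10) * 100)  # Hits@10: 66.66...%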
