In [1]:
import os

import pprint
import numpy as np
import tensorflow as tf

from data import read_data, pad_data, depad_data
from model import MemN2N

pp = pprint.PrettyPrinter()

flags = tf.app.flags

flags.DEFINE_integer("edim", 20, "internal state dimension [20]")
flags.DEFINE_integer("nhop", 3, "number of hops [3]")
flags.DEFINE_integer("mem_size", 50, "maximum number of sentences that can be encoded into memory [50]")
flags.DEFINE_integer("batch_size", 32, "batch size to use during training [32]")
flags.DEFINE_integer("nepoch", 100, "number of epoch to use during training [100]")
flags.DEFINE_integer("anneal_epoch", 25, "anneal the learning rate every <anneal_epoch> epochs [25]")
flags.DEFINE_integer("babi_task", 7, "index of bAbI task for the network to learn [1]")
flags.DEFINE_float("init_lr", 0.01, "initial learning rate [0.01]")
flags.DEFINE_float("anneal_rate", 0.5, "learning rate annealing rate [0.5]")
flags.DEFINE_float("init_mean", 0., "weight initialization mean [0.]")
flags.DEFINE_float("init_std", 0.1, "weight initialization std [0.1]")
flags.DEFINE_float("max_grad_norm", 40, "clip gradients to this norm [40]")
flags.DEFINE_string("data_dir", "./bAbI/en-valid", "dataset directory [./bAbI/en_valid]")
flags.DEFINE_string("checkpoint_dir", "./checkpoints", "checkpoint directory [./checkpoints]")
flags.DEFINE_boolean("lin_start", False, "True for linear start training, False for otherwise [False]")
flags.DEFINE_boolean("is_test", False, "True for testing, False for training [False]")
flags.DEFINE_boolean("show_progress", False, "print progress [False]")

FLAGS = flags.FLAGS

word2idx = {}
max_words = 0
max_sentences = 0

if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)

train_stories, train_questions, max_words, max_sentences = read_data('{}/qa{}_test.txt'.format(FLAGS.data_dir, FLAGS.babi_task), word2idx, max_words, max_sentences)
valid_stories, valid_questions, max_words, max_sentences = read_data('{}/qa{}_valid.txt'.format(FLAGS.data_dir, FLAGS.babi_task), word2idx, max_words, max_sentences)
test_stories, test_questions, max_words, max_sentences = read_data('{}/qa{}_test.txt'.format(FLAGS.data_dir, FLAGS.babi_task), word2idx, max_words, max_sentences)

pad_data(train_stories, train_questions, max_words, max_sentences)
pad_data(valid_stories, valid_questions, max_words, max_sentences)
pad_data(test_stories, test_questions, max_words, max_sentences)

idx2word = dict(zip(word2idx.values(), word2idx.keys()))
FLAGS.nwords = len(word2idx)
FLAGS.max_words = max_words
FLAGS.max_sentences = max_sentences

pp.pprint(flags.FLAGS.__flags)

with tf.Session() as sess:
    model = MemN2N(FLAGS, sess)
    model.build_model()

    if FLAGS.is_test:
        model.run(valid_stories, valid_questions, test_stories, test_questions)
    else:
        model.run(train_stories, train_questions, valid_stories, valid_questions)
    
    predictions, target = model.predict(train_stories, train_questions)


{'anneal_epoch': 25,
 'anneal_rate': 0.5,
 'babi_task': 7,
 'batch_size': 32,
 'checkpoint_dir': './checkpoints',
 'data_dir': './bAbI/en-valid',
 'edim': 20,
 'init_lr': 0.01,
 'init_mean': 0.0,
 'init_std': 0.1,
 'is_test': False,
 'lin_start': False,
 'max_grad_norm': 40,
 'max_sentences': 33,
 'max_words': 16,
 'mem_size': 50,
 'nepoch': 100,
 'nhop': 3,
 'nwords': 44,
 'show_progress': False}
{'learning_rate': 0.01, 'epoch': 0, 'loss': 1.3494583702087402, 'validation_loss': 1.170234591960907}
{'learning_rate': 0.01, 'epoch': 1, 'loss': 0.95766279029846191, 'validation_loss': 1.2053432273864746}
{'learning_rate': 0.01, 'epoch': 2, 'loss': 0.91837644863128665, 'validation_loss': 1.1492957496643066}
{'learning_rate': 0.01, 'epoch': 3, 'loss': 0.88957553339004514, 'validation_loss': 1.0938487553596496}
{'learning_rate': 0.01, 'epoch': 4, 'loss': 0.85897242641448979, 'validation_loss': 1.0075859284400941}
{'learning_rate': 0.01, 'epoch': 5, 'loss': 0.8295915842056274, 'validation_loss': 0.96755825281143193}
{'learning_rate': 0.01, 'epoch': 6, 'loss': 0.79424847030639645, 'validation_loss': 0.95458817481994629}
{'learning_rate': 0.01, 'epoch': 7, 'loss': 0.76740477895736692, 'validation_loss': 0.90062611103057866}
{'learning_rate': 0.01, 'epoch': 8, 'loss': 0.74515783977508543, 'validation_loss': 0.87097166061401365}
{'learning_rate': 0.01, 'epoch': 9, 'loss': 0.72283336353301997, 'validation_loss': 0.85235928535461425}
{'learning_rate': 0.01, 'epoch': 10, 'loss': 0.70465670108795164, 'validation_loss': 0.83033002853393556}
{'learning_rate': 0.01, 'epoch': 11, 'loss': 0.69624329471588131, 'validation_loss': 0.83999107360839842}
{'learning_rate': 0.01, 'epoch': 12, 'loss': 0.68744788360595699, 'validation_loss': 0.8571283888816833}
{'learning_rate': 0.01, 'epoch': 13, 'loss': 0.675952267408371, 'validation_loss': 0.83289131879806522}
{'learning_rate': 0.01, 'epoch': 14, 'loss': 0.66717684698104862, 'validation_loss': 0.81261657953262334}
{'learning_rate': 0.01, 'epoch': 15, 'loss': 0.65942105603218082, 'validation_loss': 0.7954740738868713}
{'learning_rate': 0.01, 'epoch': 16, 'loss': 0.652815468788147, 'validation_loss': 0.78586914300918576}
{'learning_rate': 0.01, 'epoch': 17, 'loss': 0.64649562740325928, 'validation_loss': 0.78121205329895016}
{'learning_rate': 0.01, 'epoch': 18, 'loss': 0.64100163578987124, 'validation_loss': 0.78237061500549321}
{'learning_rate': 0.01, 'epoch': 19, 'loss': 0.63588298559188838, 'validation_loss': 0.78696868658065799}
{'learning_rate': 0.01, 'epoch': 20, 'loss': 0.62918309259414673, 'validation_loss': 0.78879983663558961}
{'learning_rate': 0.01, 'epoch': 21, 'loss': 0.62390866684913637, 'validation_loss': 0.79172858476638797}
{'learning_rate': 0.01, 'epoch': 22, 'loss': 0.61859100008010859, 'validation_loss': 0.78638511896133423}
{'learning_rate': 0.01, 'epoch': 23, 'loss': 0.60962650299072263, 'validation_loss': 0.78300140142440799}
{'learning_rate': 0.01, 'epoch': 24, 'loss': 0.61117627906799321, 'validation_loss': 0.76510161399841303}
{'learning_rate': 0.01, 'epoch': 25, 'loss': 0.60199075412750247, 'validation_loss': 0.7767837142944336}
{'learning_rate': 0.005, 'epoch': 26, 'loss': 0.57104834389686587, 'validation_loss': 0.69777816534042358}
{'learning_rate': 0.005, 'epoch': 27, 'loss': 0.56353433656692509, 'validation_loss': 0.69042152643203736}
{'learning_rate': 0.005, 'epoch': 28, 'loss': 0.55671536469459537, 'validation_loss': 0.68690115571022037}
{'learning_rate': 0.005, 'epoch': 29, 'loss': 0.5513715415000916, 'validation_loss': 0.68400438904762273}
{'learning_rate': 0.005, 'epoch': 30, 'loss': 0.54719126892089842, 'validation_loss': 0.67960355997085575}
{'learning_rate': 0.005, 'epoch': 31, 'loss': 0.5435743436813355, 'validation_loss': 0.6734527051448822}
{'learning_rate': 0.005, 'epoch': 32, 'loss': 0.54025116920471195, 'validation_loss': 0.66722591996192937}
{'learning_rate': 0.005, 'epoch': 33, 'loss': 0.53717302989959714, 'validation_loss': 0.66192896723747252}
{'learning_rate': 0.005, 'epoch': 34, 'loss': 0.53434214162826543, 'validation_loss': 0.65757290482521058}
{'learning_rate': 0.005, 'epoch': 35, 'loss': 0.53176144170761108, 'validation_loss': 0.65403122425079341}
{'learning_rate': 0.005, 'epoch': 36, 'loss': 0.5293942174911499, 'validation_loss': 0.65114455342292787}
{'learning_rate': 0.005, 'epoch': 37, 'loss': 0.52716587305068974, 'validation_loss': 0.64850171327590944}
{'learning_rate': 0.005, 'epoch': 38, 'loss': 0.52501447486877439, 'validation_loss': 0.64542953848838802}
{'learning_rate': 0.005, 'epoch': 39, 'loss': 0.5229183745384216, 'validation_loss': 0.6413954830169678}
{'learning_rate': 0.005, 'epoch': 40, 'loss': 0.52090455198287966, 'validation_loss': 0.63670857191085817}
{'learning_rate': 0.005, 'epoch': 41, 'loss': 0.51910312509536738, 'validation_loss': 0.63250734329223635}
{'learning_rate': 0.005, 'epoch': 42, 'loss': 0.51765966129302976, 'validation_loss': 0.62931455492973332}
{'learning_rate': 0.005, 'epoch': 43, 'loss': 0.51621963071823118, 'validation_loss': 0.62659877777099604}
{'learning_rate': 0.005, 'epoch': 44, 'loss': 0.51459691524505613, 'validation_loss': 0.62400446414947508}
{'learning_rate': 0.005, 'epoch': 45, 'loss': 0.51296709990501399, 'validation_loss': 0.6217210423946381}
{'learning_rate': 0.005, 'epoch': 46, 'loss': 0.51132491016387938, 'validation_loss': 0.62004899978637695}
{'learning_rate': 0.005, 'epoch': 47, 'loss': 0.50957664537429814, 'validation_loss': 0.61909135818481442}
{'learning_rate': 0.005, 'epoch': 48, 'loss': 0.50762390160560611, 'validation_loss': 0.6189188098907471}
{'learning_rate': 0.005, 'epoch': 49, 'loss': 0.50551674079895015, 'validation_loss': 0.61961162745952603}
{'learning_rate': 0.005, 'epoch': 50, 'loss': 0.50350008440017702, 'validation_loss': 0.62075054883956904}
{'learning_rate': 0.0025, 'epoch': 51, 'loss': 0.49740002393722532, 'validation_loss': 0.58172676682472224}
{'learning_rate': 0.0025, 'epoch': 52, 'loss': 0.48924163341522214, 'validation_loss': 0.58311765193939213}
{'learning_rate': 0.0025, 'epoch': 53, 'loss': 0.48645117712020874, 'validation_loss': 0.58871693253517154}
{'learning_rate': 0.0025, 'epoch': 54, 'loss': 0.48411962127685548, 'validation_loss': 0.5995426750183106}
{'learning_rate': 0.0025, 'epoch': 55, 'loss': 0.48188404774665833, 'validation_loss': 0.61013479828834538}
{'learning_rate': 0.0025, 'epoch': 56, 'loss': 0.47945734596252443, 'validation_loss': 0.61313316941261287}
{'learning_rate': 0.0025, 'epoch': 57, 'loss': 0.47698738622665404, 'validation_loss': 0.61105804562568666}
{'learning_rate': 0.0025, 'epoch': 58, 'loss': 0.47472467184066774, 'validation_loss': 0.60705380916595464}
{'learning_rate': 0.0025, 'epoch': 59, 'loss': 0.4726275420188904, 'validation_loss': 0.60334622025489804}
{'learning_rate': 0.0025, 'epoch': 60, 'loss': 0.47061111617088319, 'validation_loss': 0.60120776414871213}
{'learning_rate': 0.0025, 'epoch': 61, 'loss': 0.46868098974227906, 'validation_loss': 0.60025244951248169}
{'learning_rate': 0.0025, 'epoch': 62, 'loss': 0.4669323387145996, 'validation_loss': 0.5997404909133911}
{'learning_rate': 0.0025, 'epoch': 63, 'loss': 0.465348153591156, 'validation_loss': 0.5993447291851044}
{'learning_rate': 0.0025, 'epoch': 64, 'loss': 0.46386429655551908, 'validation_loss': 0.59890996098518368}
{'learning_rate': 0.0025, 'epoch': 65, 'loss': 0.46244070756435396, 'validation_loss': 0.5984150516986847}
{'learning_rate': 0.0025, 'epoch': 66, 'loss': 0.46101350128650664, 'validation_loss': 0.5979280257225037}
{'learning_rate': 0.0025, 'epoch': 67, 'loss': 0.45950625443458559, 'validation_loss': 0.59754816889762874}
{'learning_rate': 0.0025, 'epoch': 68, 'loss': 0.45817958438396456, 'validation_loss': 0.5972286522388458}
{'learning_rate': 0.0025, 'epoch': 69, 'loss': 0.45749873352050779, 'validation_loss': 0.59666479945182804}
{'learning_rate': 0.0025, 'epoch': 70, 'loss': 0.45642544758319853, 'validation_loss': 0.59621699810028073}
{'learning_rate': 0.0025, 'epoch': 71, 'loss': 0.45486854600906373, 'validation_loss': 0.59617541313171385}
{'learning_rate': 0.0025, 'epoch': 72, 'loss': 0.45493191504478453, 'validation_loss': 0.59531668186187747}
{'learning_rate': 0.0025, 'epoch': 73, 'loss': 0.45575080823898317, 'validation_loss': 0.59497429132461543}
{'learning_rate': 0.0025, 'epoch': 74, 'loss': 0.45119453442096707, 'validation_loss': 0.59549145340919496}
{'learning_rate': 0.0025, 'epoch': 75, 'loss': 0.45521457958221434, 'validation_loss': 0.59324149847030638}
{'learning_rate': 0.00125, 'epoch': 76, 'loss': 0.4412108964920044, 'validation_loss': 0.57320029735565181}
{'learning_rate': 0.00125, 'epoch': 77, 'loss': 0.43739699697494505, 'validation_loss': 0.57419827461242678}
{'learning_rate': 0.00125, 'epoch': 78, 'loss': 0.43764109563827513, 'validation_loss': 0.57321208953857417}
{'learning_rate': 0.00125, 'epoch': 79, 'loss': 0.4348056709766388, 'validation_loss': 0.57525240898132324}
{'learning_rate': 0.00125, 'epoch': 80, 'loss': 0.43579065454006194, 'validation_loss': 0.57327941894531254}
{'learning_rate': 0.00125, 'epoch': 81, 'loss': 0.43294333553314207, 'validation_loss': 0.57508021354675298}
{'learning_rate': 0.00125, 'epoch': 82, 'loss': 0.43427425003051756, 'validation_loss': 0.57270117521286013}
{'learning_rate': 0.00125, 'epoch': 83, 'loss': 0.43140757477283476, 'validation_loss': 0.57429283618927007}
{'learning_rate': 0.00125, 'epoch': 84, 'loss': 0.43299649727344514, 'validation_loss': 0.57163932085037228}
{'learning_rate': 0.00125, 'epoch': 85, 'loss': 0.43005666708946227, 'validation_loss': 0.57302159309387202}
{'learning_rate': 0.00125, 'epoch': 86, 'loss': 0.43188328969478607, 'validation_loss': 0.57013682365417484}
{'learning_rate': 0.00125, 'epoch': 87, 'loss': 0.42881473660469055, 'validation_loss': 0.57129777669906612}
{'learning_rate': 0.00125, 'epoch': 88, 'loss': 0.43085678410530093, 'validation_loss': 0.56822734117507934}
{'learning_rate': 0.00125, 'epoch': 89, 'loss': 0.42761960911750796, 'validation_loss': 0.56920595884323122}
{'learning_rate': 0.00125, 'epoch': 90, 'loss': 0.4297959452867508, 'validation_loss': 0.56604688882827758}
{'learning_rate': 0.00125, 'epoch': 91, 'loss': 0.42640608394145968, 'validation_loss': 0.56694009304046633}
{'learning_rate': 0.00125, 'epoch': 92, 'loss': 0.42875660228729245, 'validation_loss': 0.56386746048927305}
{'learning_rate': 0.00125, 'epoch': 93, 'loss': 0.42516977584362031, 'validation_loss': 0.56480321288108826}
{'learning_rate': 0.00125, 'epoch': 94, 'loss': 0.42768951416015627, 'validation_loss': 0.56189583897590634}
{'learning_rate': 0.00125, 'epoch': 95, 'loss': 0.4238904254436493, 'validation_loss': 0.56291117668151858}
{'learning_rate': 0.00125, 'epoch': 96, 'loss': 0.42654670858383181, 'validation_loss': 0.56016628265380863}
{'learning_rate': 0.00125, 'epoch': 97, 'loss': 0.42254860079288481, 'validation_loss': 0.56123896002769469}
{'learning_rate': 0.00125, 'epoch': 98, 'loss': 0.42528756022453307, 'validation_loss': 0.55860994338989256}
{'learning_rate': 0.00125, 'epoch': 99, 'loss': 0.42112210416793822, 'validation_loss': 0.55971255421638488}
 [*] Reading checkpoints...
INFO:tensorflow:Restoring parameters from ./checkpoints\MemN2N.model-3276

In [3]:
index = 25

depad_data(train_stories, train_questions)

question = train_questions[index]['question']
answer = train_questions[index]['answer']
story_index = train_questions[index]['story_index']
sentence_index = train_questions[index]['sentence_index']

story = train_stories[story_index][:sentence_index + 1]

story = [list(map(idx2word.get, sentence)) for sentence in story]
question = list(map(idx2word.get, question))
prediction = [idx2word[np.argmax(predictions[index])]]
answer = list(map(idx2word.get, answer))

print('Story:')
pp.pprint(story)
print('\nQuestion:')
pp.pprint(question)
print('\nPrediction:')
pp.pprint(prediction)
print('\nAnswer:')
pp.pprint(answer)
print('\nCorrect:')
pp.pprint(prediction == answer)


Story:
[['daniel', 'went', 'to', 'the', 'bathroom'],
 ['john', 'moved', 'to', 'the', 'office'],
 ['daniel', 'went', 'to', 'the', 'kitchen'],
 ['john', 'got', 'the', 'apple', 'there']]

Question:
['how', 'many', 'objects', 'is', 'john', 'carrying']

Prediction:
['one']

Answer:
['one']

Correct:
True

In [ ]: