In [3]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import gridspec

In [4]:
import sys

# Make modules in the parent directory importable (presumably where ``tasks``
# lives). Guarded so that re-running this cell does not append the same entry
# to sys.path again.
if '..' not in sys.path:
    sys.path.append('..')

In [5]:
# Star import supplies the names used below and in later cells: from the visible
# code that includes tf, NTMCell, NTM and copy (and presumably np, which plot()
# uses without importing it). NOTE(review): prefer explicit imports.
from tasks import *

# Input and output dimensionality passed to NTMCell.
input_dim=10
output_dim=10

# TensorFlow session shared by the model and the later copy()/plot() calls.
sess = tf.InteractiveSession()

cell = NTMCell(input_dim=input_dim, output_dim=output_dim)
# NOTE(review): the positional args 1, 10, 100 are presumably min_length,
# max_length and test_max_length — confirm against the NTM signature.
# forward_only=True presumably builds the graph for inference only — confirm.
ntm = NTM(cell, sess, 1, 10, 100, forward_only=True)

# Load the saved 'copy' model from ../checkpoint (presumably the trained
# copy-task weights).
ntm.load('../checkpoint', 'copy')


 [*] Building a NTM model
Percent: [####################] 100.00% Finished.
 [*] Build a NTM model finished
 [*] Reading checkpoints...

In [6]:
# Run the copy task on a length-5 sequence; copy() prints the true output,
# the predicted output and the loss.
copy(ntm, 5, sess)


 true output : 
  ##  #   
  #### ## 
   #######
  # ## # #
    # ## #
 predicted output :
  ##  #   
  #### ## 
   #######
  # ## # #
    # ## #
 Loss : 0.000001

In [32]:
def plot(ntm, seq_length, sess):
    """Run one copy-task episode and plot its sequences and memory weightings.

    Calls ``copy(ntm, seq_length, sess, print_=False)`` and draws a new figure
    with four stacked panels:

    1. the input: a start-marker vector, the input vectors and an end-marker
       vector, followed by blank steps for the output phase;
    2. the model output, preceded by blank steps for every framed input step,
       so panels 1 and 2 share one time axis;
    3. the write weightings (one row per time step);
    4. the read weightings (one row per time step).

    The copy loss is printed afterwards. Returns ``None``.

    Args:
        ntm: model passed through to ``copy``.
        seq_length: length of the sequence to copy; also selects the figure
            size and the panels' relative heights.
        sess: session passed through to ``copy``.
    """
    seq, outputs, read_w, write_w, loss = copy(ntm, seq_length, sess, print_=False)

    # list() so the marker concatenation below is a list join even if copy()
    # hands back an ndarray (``[x] + ndarray`` would broadcast-add instead).
    seq = list(seq)
    outputs = np.asarray(outputs)

    # Marker vectors framing the input: channel 0 set at the start, channel 1
    # set at the end.
    start_marker = np.zeros_like(seq[0])
    start_marker[0] = 1
    end_marker = np.zeros_like(seq[0])
    end_marker[1] = 1
    framed_input = np.array([start_marker] + seq + [end_marker])

    # Pad both sequences onto a shared time axis: the input is followed by
    # zeros while the model answers; the answer is preceded by zeros for every
    # framed input step.
    input_panel = np.concatenate([framed_input, np.zeros_like(outputs)])
    output_panel = np.concatenate(
        [np.zeros((len(framed_input),) + outputs.shape[1:]), outputs])

    # (minimum seq_length, figsize, height_ratios), checked top to bottom;
    # shorter sequences fall back to the default below.
    layouts = (
        (80, (20, 16), [0.4, 0.4, 1.6, 1.6]),
        (60, (20, 14), [0.6, 0.6, 1.4, 1.4]),
        (50, (20, 14), [0.8, 0.8, 1.2, 1.2]),
        (20, (20, 14), [0.9, 0.9, 1.1, 1.1]),
    )
    figsize, height_ratios = next(
        ((size, ratios) for min_len, size, ratios in layouts if seq_length >= min_len),
        ((20, 10), [1, 1, 1, 1]),
    )

    # A fresh figure per call (instead of reusing figure number 1) so repeated
    # calls never draw on top of an earlier plot.
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(4, 1, height_ratios=height_ratios)

    ax_in = fig.add_subplot(gs[0])
    ax_in.imshow(input_panel.T, interpolation='nearest')
    ax_in.set_ylabel('input')

    ax_out = fig.add_subplot(gs[1])
    ax_out.imshow(output_panel.T, interpolation='nearest')
    ax_out.set_xlabel('time')
    ax_out.set_ylabel('output')

    # The first and last weightings are left out of both weighting panels.
    # NOTE(review): confirm which time steps those entries correspond to.
    ax_write = fig.add_subplot(gs[2])
    ax_write.imshow(write_w[1:-1], cmap='Greys', interpolation='nearest')
    ax_write.set_xlabel('write weight')
    ax_write.set_ylabel('time')

    ax_read = fig.add_subplot(gs[3])
    ax_read.imshow(read_w[1:-1], cmap='Greys', interpolation='nearest')
    ax_read.set_xlabel('read weight')
    ax_read.set_ylabel('time')

    print("Loss : %f" % loss)

In [20]:
# Copy task at length 5: input/output and memory weightings, plus loss.
plot(ntm, seq_length=5, sess=sess)


Loss : 0.000000

In [21]:
# Copy task at length 10.
plot(ntm, seq_length=10, sess=sess)


Loss : 0.000003

In [33]:
# Copy task at length 20.
plot(ntm, seq_length=20, sess=sess)


Loss : 0.000009

In [34]:
# Copy task at length 30.
plot(ntm, seq_length=30, sess=sess)


Loss : 0.000024

In [37]:
# Copy task at length 50.
plot(ntm, seq_length=50, sess=sess)


Loss : 0.000257

In [38]:
# Copy task at length 60.
plot(ntm, seq_length=60, sess=sess)


Loss : 0.029731

In [39]:
# Copy task at length 70.
plot(ntm, seq_length=70, sess=sess)


Loss : 0.035636

In [40]:
# Copy task at length 80.
plot(ntm, seq_length=80, sess=sess)


Loss : 0.002402

In [44]:
# Copy task at length 100. NOTE(review): the recorded loss here (~1113) is
# orders of magnitude above the lengths <= 80 (all < 0.04); the model
# presumably fails to generalise to this length — worth stating in prose.
plot(ntm, seq_length=100, sess=sess)


Loss : 1113.521484

In [ ]: