In [3]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import gridspec
In [4]:
import sys
sys.path.append('..')
In [5]:
from tasks import *
# Build the NTM in inference mode (forward_only=True) and load the 'copy'
# checkpoint from ../checkpoint. Both dimensionalities handed to NTMCell are 10.
input_dim, output_dim = 10, 10
sess = tf.InteractiveSession()
cell = NTMCell(output_dim=output_dim, input_dim=input_dim)
# NOTE(review): the positional 1, 10, 100 are presumably the min/max training
# sequence length and the max test length — confirm against NTM.__init__.
ntm = NTM(cell, sess, 1, 10, 100, forward_only=True)
ntm.load('../checkpoint', 'copy')
In [6]:
# Quick check: run the copy task once at sequence length 5. `print_` is not
# passed here (unlike in plot() below), so copy's default printing applies;
# the returned tuple (seq, outputs, read_w, write_w, loss) is shown as the
# cell output because it is the last expression.
copy(ntm, 5, sess)
In [32]:
def _figure_layout(seq_length):
    """Pick the figure size and panel height ratios for a sequence length.

    The weighting panels have one row per time step, so as ``seq_length``
    grows they are given a larger share of the figure height.

    Args:
        seq_length: length of the sequence being copied.

    Returns:
        ``(figsize, height_ratios)`` for ``plt.figure`` / ``gridspec.GridSpec``;
        the ratios are for the input, output, write and read panels.
    """
    # (minimum seq_length, figsize, height ratios), checked from longest down.
    layouts = (
        (80, (20, 16), [0.4, 0.4, 1.6, 1.6]),
        (60, (20, 14), [0.6, 0.6, 1.4, 1.4]),
        (50, (20, 14), [0.8, 0.8, 1.2, 1.2]),
        (20, (20, 14), [0.9, 0.9, 1.1, 1.1]),
    )
    for min_length, figsize, height_ratios in layouts:
        if seq_length >= min_length:
            return figsize, height_ratios
    return (20, 10), [1, 1, 1, 1]


def plot(ntm, seq_length, sess):
    """Run the copy task once and plot inputs, outputs and memory weightings.

    Draws four stacked panels -- input sequence, model output, write
    weightings, read weightings -- then prints the loss.

    Args:
        ntm: the loaded NTM, forwarded to ``copy``.
        seq_length: length of the sequence to copy.
        sess: TensorFlow session, forwarded to ``copy``.
    """
    seq, outputs, read_w, write_w, loss = copy(ntm, seq_length, sess, print_=False)
    outputs = np.asarray(outputs)

    # Frame the input with two marker vectors (one-hot in channel 0 before,
    # channel 1 after -- presumably the task's start/end delimiters) and pad it
    # with zeros over the output phase. Pad the output with zeros over the
    # input phase (+2 marker steps) so both panels share one time axis.
    start = np.zeros_like(seq[0])
    start[0] = 1
    end = np.zeros_like(seq[0])
    end[1] = 1
    framed_input = np.array([start] + list(seq) + [end])
    input_panel = np.concatenate([framed_input, np.zeros_like(outputs)])
    output_pad = np.zeros((outputs.shape[0] + 2,) + outputs.shape[1:])
    output_panel = np.concatenate([output_pad, outputs])

    figsize, height_ratios = _figure_layout(seq_length)
    # New figure on every call so repeated calls in one cell don't draw
    # onto the same figure.
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(4, 1, height_ratios=height_ratios)

    ax_in = fig.add_subplot(gs[0])
    ax_in.imshow(input_panel.T, interpolation='nearest')
    ax_in.set_ylabel('input')

    ax_out = fig.add_subplot(gs[1])
    ax_out.imshow(output_panel.T, interpolation='nearest')
    ax_out.set_xlabel('time')
    ax_out.set_ylabel('output')

    # Weightings are drawn one row per time step. The first and last recorded
    # entries are skipped -- NOTE(review): presumably initial/final states;
    # confirm against what `copy` records.
    weight_panels = (
        (gs[2], write_w, 'write weight'),
        (gs[3], read_w, 'read weight'),
    )
    for spec, weights, label in weight_panels:
        ax = fig.add_subplot(spec)
        ax.imshow(weights[1:-1], cmap='Greys', interpolation='nearest')
        ax.set_xlabel(label)
        ax.set_ylabel('time')

    print("Loss : %f" % loss)
In [20]:
# Sweep increasing sequence lengths (5 up to 100) to compare the model's
# output and its read/write weightings as the copied sequence gets longer.
plot(ntm, 5, sess)
In [21]:
plot(ntm, 10, sess)
In [33]:
plot(ntm, 20, sess)
In [34]:
plot(ntm, 30, sess)
In [37]:
plot(ntm, 50, sess)
In [38]:
plot(ntm, 60, sess)
In [39]:
plot(ntm, 70, sess)
In [40]:
plot(ntm, 80, sess)
In [44]:
plot(ntm, 100, sess)
In [ ]: