Testing the trained weight matrices (not in an ensemble)


In [3]:
import nengo
import numpy as np
import cPickle
import matplotlib.pyplot as plt
from matplotlib import pylab
import matplotlib.animation as animation

Load the weight matrices from the training


In [4]:
#Weight matrices generated by the neural network after training

#Maps the label vectors to the neuron activity of the ensemble
label_weights = cPickle.load(open("label_weights1000.p", "rb"))
#Maps the activity of the neurons to the visual representation of the image
activity_to_img_weights = cPickle.load(open("activity_to_img_weights_scale1000.p", "rb"))
#Maps the activity of the neurons of an image with the activity of the neurons of an image scaled
scale_up_weights = cPickle.load(open("scale_up_weights1000.p", "rb"))
scale_down_weights = cPickle.load(open("scale_down_weights1000.p", "rb"))

#Create the pointers for the numbers
temp = np.diag([1]*10)

ZERO = temp[0]
ONE =  temp[1]
TWO =  temp[2]
THREE= temp[3]
FOUR = temp[4]
FIVE = temp[5]
SIX =  temp[6]
SEVEN =temp[7]
EIGHT= temp[8]
NINE = temp[9]

labels =[ZERO,ONE,TWO,THREE,FOUR,FIVE,SIX,SEVEN,EIGHT,NINE]

In [ ]:
#Visualize the one hot representation
print(ZERO)
print(ONE)

Visualize the digit from one hot representation through the activity weight matrix to the image representation

  • Image is average digit from mnist dataset

In [ ]:
#Change this to imagine different digits
imagine = ZERO
#Can also imagine combitnations of numbers (ZERO + ONE)

#Label to activity
test_activity = np.dot(imagine,label_weights)
#Image decoded 
test_output_img = np.dot(test_activity, activity_to_img_weights)

plt.imshow(test_output_img.reshape(28,28),cmap='gray')
plt.show()

Visualize the rotation of the image using the weight matrix from activity to activity

  • does not use the weight matrix used on the recurrent connection

In [6]:
#Change this to visualize different digits
imagine = ZERO 

#How long the animation should go for
frames=5

#Make a list of the activation of rotated images and add first frame
rot_seq = []
rot_seq.append(np.dot(imagine,label_weights)) #Map the label vector to the activity vector
test_output_img = np.dot(rot_seq[0], activity_to_img_weights) #Map the activity to the visual representation

#add the rest of the frames, using the previous frame to calculate the current frame
for i in range(1,frames):
    rot_seq.append(np.dot(rot_seq[i-1],scale_down_weights)) #add the activity of the current image to the list
    test_output_img = np.dot(rot_seq[i], activity_to_img_weights) #map the new activity to the visual image
for i in range(1,frames*2):
    rot_seq.append(np.dot(rot_seq[frames+i-2],scale_up_weights)) #add the activity of the current image to the list
    test_output_img = np.dot(rot_seq[i], activity_to_img_weights) #map the new activity to the visual image 

#Animation of rotation
fig = plt.figure()

def updatefig(i):
    image_vector = np.dot(rot_seq[i], activity_to_img_weights) #map the activity to the image representation
    im = pylab.imshow(np.reshape(image_vector,(28,28), 'F').T, cmap=plt.get_cmap('Greys_r'),animated=True)
    
    return im,

ani = animation.FuncAnimation(fig, updatefig, interval=100, blit=True)
plt.show()


Exception in Tkinter callback
Traceback (most recent call last):
  File "C:\Python27\lib\lib-tk\Tkinter.py", line 1536, in __call__
    return self.func(*args)
  File "C:\Python27\lib\lib-tk\Tkinter.py", line 587, in callit
    func(*args)
  File "C:\Python27\lib\site-packages\matplotlib\backends\backend_tkagg.py", line 147, in _on_timer
    TimerBase._on_timer(self)
  File "C:\Python27\lib\site-packages\matplotlib\backend_bases.py", line 1305, in _on_timer
    ret = func(*args, **kwargs)
  File "C:\Python27\lib\site-packages\matplotlib\animation.py", line 1021, in _step
    still_going = Animation._step(self, *args)
  File "C:\Python27\lib\site-packages\matplotlib\animation.py", line 827, in _step
    self._draw_next_frame(framedata, self._blit)
  File "C:\Python27\lib\site-packages\matplotlib\animation.py", line 846, in _draw_next_frame
    self._draw_frame(framedata)
  File "C:\Python27\lib\site-packages\matplotlib\animation.py", line 1212, in _draw_frame
    self._drawn_artists = self._func(framedata, *self._args)
  File "<ipython-input-6-c7be4de48ccb>", line 24, in updatefig
    image_vector = np.dot(rot_seq[i], activity_to_img_weights) #map the activity to the image representation
IndexError: list index out of range