Compare the similarity of an image and its rotated version

  • Requires the function points branch of nengo
  • Requires the vision branch of nengo_extras

In [1]:
import nengo
import numpy as np
import cPickle
from nengo_extras.data import load_mnist
from nengo_extras.vision import Gabor, Mask
from matplotlib import pylab
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import scipy.ndimage
from skimage.measure import compare_ssim as ssim

Load the MNIST database


In [2]:
# --- load the data
img_rows = img_cols = 28  # MNIST images are 28x28 pixels

(X_train, y_train), (X_test, y_test) = load_mnist()

# Rescale pixel values to the range -1 to 1
# (assumes load_mnist returns values in [0, 1] -- TODO confirm)
X_train, X_test = (2 * images - 1 for images in (X_train, X_test))

Each digit is represented by a one-hot vector, where the index of the 1 is the digit's value.


In [3]:
# 10x10 integer identity matrix: row k is the one-hot code for digit k
temp = np.eye(10, dtype=int)

# Unpacking iterates over the rows, binding each digit name to its one-hot row
ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE = temp

labels = [ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE]

# Side length of the square images (dim**2 pixels per image)
dim = 28

Load the saved weight matrices that were created by training the model


In [4]:
def load_pickle(path):
    """Load and return the single pickled object stored at `path`.

    The file is opened in binary mode, since pickles are binary data and text
    mode on Windows can corrupt them. It is closed again once loading finishes.
    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, "rb") as f:
        return cPickle.load(f)


label_weights = load_pickle("label_weights5000.p")
activity_to_img_weights = load_pickle("activity_to_img_weights5000.p")
rotated_clockwise_after_encoder_weights = load_pickle("rotated_after_encoder_weights_clockwise5000.p")
rotated_counter_after_encoder_weights = load_pickle("rotated_after_encoder_weights5000.p")

#identity_after_encoder_weights = cPickle.load(open("identity_after_encoder_weights1000.p","r"))


#rotation_clockwise_weights = cPickle.load(open("rotation_clockwise_weights1000.p","rb"))
#rotation_counter_weights = cPickle.load(open("rotation_weights1000.p","rb"))

In [5]:
#Training with filters used on train images
#low_pass_weights = cPickle.load(open("low_pass_weights1000.p", "rb"))
#rotated_counter_after_encoder_weights_noise = cPickle.load(open("rotated_after_encoder_weights_counter_filter_noise5000.p", "r"))
#rotated_counter_after_encoder_weights_filter = cPickle.load(open("rotated_after_encoder_weights_counter_filter5000.p", "r"))

Functions to perform the inhibition of each ensemble


In [6]:
# Inhibition signals as functions of simulation time t.
# A return value of zero means no inhibition; dim**2 is the "inhibit" level.

def inhibit_rotate_clockwise(t):
    """Inhibit the clockwise rotation before t = 0.5, release it afterwards."""
    return dim**2 if t < 0.5 else 0

def inhibit_rotate_counter(t):
    """Leave the counter-clockwise rotation free before t = 0.5, inhibit it afterwards."""
    return 0 if t < 0.5 else dim**2

def inhibit_identity(t):
    """Always inhibit the identity path.

    NOTE(review): both branches of the original `t < 0.3` test returned
    dim**2, so the threshold had no effect -- confirm whether one branch was
    meant to return 0.
    """
    return dim**2

In [7]:
def intense(img):
    """Return a copy of `img` binarized by sign.

    Negative entries become -1 and positive entries become 1. Zeros (and
    NaNs) are left unchanged, and `img` itself is not modified.
    """
    result = img.copy()
    below, above = result < 0, result > 0
    result[below] = -1
    result[above] = 1
    return result

def node_func(t, x):
    """Clean up the image `x` by binarizing it with `intense`; `t` is ignored.

    Takes (t, x) like a nengo Node function. Alternatives that were tried:
    scipy.ndimage.gaussian_filter(x, sigma=1) and
    scipy.ndimage.median_filter(x, 3).
    """
    return intense(x)

In [8]:
# Create the stimulus: decode the THREE label into an image, then rotate it
weight = np.dot(label_weights, activity_to_img_weights)

img = np.dot(THREE, weight)
# (alternatively use a raw training image: img = X_train[7])

# Rotate by 40 degrees, keeping the dim x dim frame (reshape=False). Exposed
# corners are filled with -1, the low end of the [-1, 1] range the images
# were normalized to.
rot_img = scipy.ndimage.rotate(img.reshape(dim, dim), 40, reshape=False, cval=-1).ravel()

# Show the original and the rotated stimulus side by side
plt.subplot(121)
plt.title("Original")
pylab.imshow(img.reshape(dim, dim), cmap="gray")
plt.subplot(122)
plt.title("Rotated 40 degrees")
pylab.imshow(rot_img.reshape(dim, dim), cmap="gray")
plt.show()

In [9]:
def _split_halves(x):
    """Split a concatenated vector into its two equal halves (first, second).

    The split point is derived from the input length instead of a hard-coded
    neuron count, so it stays correct if n_hid changes.
    """
    x = np.asarray(x)
    if x.shape[0] % 2:
        raise ValueError("expected two concatenated activity vectors of equal "
                         "length, got total length %d" % x.shape[0])
    half = x.shape[0] // 2
    return x[:half], x[half:]

def ssim_func(x):
    """Structural similarity between the images decoded from both halves of `x`.

    `x` is the concatenation of two neuron-activity vectors (first half: the
    static ensemble, second half: the rotating ensemble). Each half is decoded
    to a 28x28 image through activity_to_img_weights before comparison.
    """
    static_act, rotated_act = _split_halves(x)
    img1 = np.dot(static_act, activity_to_img_weights)
    img2 = np.dot(rotated_act, activity_to_img_weights)

    return ssim(img1.reshape(28, 28), img2.reshape(28, 28))

def activity_sim_func(x):
    """Unnormalized dot product between the two halves of `x` (neuron activities)."""
    u, v = _split_halves(x)

    # a = nengo.spa.similarity(u, v, normalize=True)  # normalized alternative
    return np.dot(u, v)

The network where the mental imagery and rotation occurs

  • The state, seed and ensemble parameters (including encoders) must all be the same for the saved weight matrices to work
  • The number of neurons (n_hid) must be the same as was used for training
  • The input must be shown for a short period of time to be able to view the rotation
  • The recurrent connection must be from the neurons because the weight matrices were trained on the neuron activities

In [10]:
rng = np.random.RandomState(9)
n_hid = 5000
model = nengo.Network(seed=3)
with model:
    # Stimulus to be matched to (present for the whole run)
    static_stim = nengo.Node(img)

    # Stimulus only shows for a brief period of time; afterwards the recurrent
    # connection alone carries (and rotates) the representation
    rot_stim = nengo.Node(lambda t: rot_img if t < 0.1 else 0)

    ens_params = dict(
        eval_points=X_train,
        neuron_type=nengo.LIF(),
        intercepts=nengo.dists.Choice([-0.5]),
        max_rates=nengo.dists.Choice([100]),
    )

    # Linear filters used for edge detection as encoders (more plausible for
    # the human visual system), masked into the dim x dim image. Both ensembles
    # share them, along with seed and parameters, so the saved weights apply.
    encoders = Gabor().generate(n_hid, (11, 11), rng=rng)
    encoders = Mask((dim, dim)).populate(encoders, rng=rng, flatten=True)

    # Ensemble that represents the image to be matched.
    # Cannot be Direct mode: the similarity measures compare neuron activities.
    static_ens = nengo.Ensemble(n_hid, dim**2, seed=3, encoders=encoders, **ens_params)
    nengo.Connection(static_stim, static_ens)

    # Ensemble that represents the image and will be rotated
    ens = nengo.Ensemble(n_hid, dim**2, seed=3, encoders=encoders, **ens_params)

    # Feed the (briefly shown) rotated stimulus straight into the ensemble
    nengo.Connection(rot_stim, ens)

    # Recurrent connection on the neurons of the ensemble performs the rotation;
    # the learned matrix maps neuron activities to neuron inputs
    nengo.Connection(ens.neurons, ens.neurons, transform=rotated_clockwise_after_encoder_weights.T, synapse=0.1)

    # Bring both ensembles' neuron activities together to calculate similarity.
    # Ideally this would be made of real neurons instead of Direct mode:
    # combine = nengo.Ensemble(10000, 784*2)
    combine = nengo.Ensemble(1000, 2 * n_hid, neuron_type=nengo.Direct())

    # First half: static ensemble activities; second half: rotating ensemble
    nengo.Connection(static_ens.neurons, combine[:n_hid])
    nengo.Connection(ens.neurons, combine[n_hid:])

    # Structural similarity measure
    ssim_node = nengo.Node(None, size_in=1)
    nengo.Connection(combine, ssim_node, function=ssim_func)

    # Neural activity similarity measure
    activity_sim_node = nengo.Node(None, size_in=1)
    nengo.Connection(combine, activity_sim_node, function=activity_sim_func)

    # Collect output, use synapse for smoothing
    probe = nengo.Probe(ens.neurons, synapse=0.1)
    static_probe = nengo.Probe(static_ens.neurons, synapse=0.1)
    ssim_probe = nengo.Probe(ssim_node, synapse=0.1)
    activity_sim_probe = nengo.Probe(activity_sim_node, synapse=0.1)

In [11]:
# Build the network into a simulator; later cells read sim.data[...] and pass
# `sim` to nengo.utils.ensemble.tuning_curves
sim = nengo.Simulator(model)

In [12]:
# Run 5 s of simulated time. The rotating stimulus is only present for t < 0.1,
# so after that the recurrent connection drives the representation.
sim.run(5)


Simulation finished in 0:06:05.                                                 

The following is not part of the brain model; it is only used to view the ensemble's output.

Since the probe records the neurons themselves, its output must be transformed from neural activity into a visual image.


In [31]:
%matplotlib inline
#Graph of probe output
plt.plot(sim.trange(), sim.data[ssim_probe], 'k', label="SSIM")
plt.legend()
plt.show()



In [13]:
# Dot-product similarity between the two ensembles' neuron activities over time
plt.plot(sim.trange(), sim.data[activity_sim_probe], 'k', label="Activity Similarity")
plt.xlabel("Time (s)")
plt.ylabel("Activity similarity (dot product)")
plt.legend()
plt.show()

In [38]:
# Turn probe activity into images: each row of sim.data[probe] is one time
# step's neuron activities. One matrix product replaces the per-step loop; the
# result is kept as a list of per-step image vectors.
output_acts = list(np.dot(sim.data[probe], activity_to_img_weights))

In [83]:
# Animation for probe output: one frame per decoded time step
fig = plt.figure()

# Draw the image once and update its pixel data per frame, instead of stacking
# a new AxesImage on the axes every frame
im = pylab.imshow(np.reshape(output_acts[0], (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'), animated=True)


def updatefig(i):
    """Show decoded frame i and return the artists to redraw (needed for blit=True)."""
    im.set_array(np.reshape(output_acts[i], (dim, dim), 'F').T)
    # Rescale the colour limits to this frame's range, as a fresh imshow would
    im.autoscale()
    return im,

# `frames` bounds the animation to the available data; without it FuncAnimation
# counts up indefinitely and indexes past the end of output_acts.
# `interval` is in milliseconds.
# NOTE(review): with the inline backend plt.show() does not play animations, and
# the TkAgg attempt recorded below failed -- consider an interactive backend.
ani = animation.FuncAnimation(fig, updatefig, frames=len(output_acts), interval=1, blit=True)
plt.show()


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-83-d4eda138fd3f> in <module>()
     12 
     13 ani = animation.FuncAnimation(fig, updatefig, interval=0.1, blit=True)
---> 14 plt.show()

C:\Python27\lib\site-packages\matplotlib\pyplot.pyc in show(*args, **kw)
    242     """
    243     global _show
--> 244     return _show(*args, **kw)
    245 
    246 

C:\Python27\lib\site-packages\matplotlib\backend_bases.pyc in __call__(self, block)
    190 
    191         if not is_interactive() or get_backend() == 'WebAgg':
--> 192             self.mainloop()
    193 
    194     def mainloop(self):

C:\Python27\lib\site-packages\matplotlib\backends\backend_tkagg.pyc in mainloop(self)
     72 class Show(ShowBase):
     73     def mainloop(self):
---> 74         Tk.mainloop()
     75 
     76 show = Show()

C:\Python27\lib\lib-tk\Tkinter.pyc in mainloop(n)
    412 def mainloop(n=0):
    413     """Run the main loop of Tcl."""
--> 414     _default_root.tk.mainloop(n)
    415 
    416 getint = int

AttributeError: 'NoneType' object has no attribute 'tk'

In [39]:
# Decoded image at selected time steps (panel titles are step indices)
snapshot_steps = [100, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]

for panel, step in enumerate(snapshot_steps, start=1):
    plt.subplot(2, 6, panel)
    plt.title(str(step))
    # Clamp to the last recorded step, e.g. "5000" shows index 4999
    frame = min(step, len(output_acts) - 1)
    pylab.imshow(np.reshape(output_acts[frame], (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

plt.show()


Pickle the probe's output so a long simulation does not have to be rerun.


In [ ]:
# The filename includes the number of neurons and which digit is being rotated
# (the stimulus built earlier is THREE: img = np.dot(THREE, weight))
filename = "mental_rotation_output_THREE_" + str(n_hid) + ".p"
with open(filename, "wb") as f:
    cPickle.dump(sim.data[probe], f)

Testing


In [171]:
# Take one decoded frame, clean it with node_func, then repeatedly apply the
# filter-trained counter-clockwise weights and show the three stages.
#
# These weights are only loaded by a commented-out line near the top of the
# notebook, so load them here to keep this cell runnable on a fresh kernel.
with open("rotated_after_encoder_weights_counter_filter5000.p", "rb") as f:
    rotated_counter_after_encoder_weights_filter = cPickle.load(f)

testing = output_acts[300]
plt.subplot(131)
pylab.imshow(np.reshape(testing, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

# Binarize the image
testing = node_func(0, testing)

plt.subplot(132)
pylab.imshow(np.reshape(testing, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

# Get activity of image
_, testing_act = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=testing)

# Get encoder outputs, then convert them to activities
testing_filter = np.dot(testing_act, rotated_counter_after_encoder_weights_filter)
testing_filter = ens.neuron_type.rates(testing_filter, sim.data[ens].gain, sim.data[ens].bias)

# Each pass: rotate, decode to an image, clean it, re-encode to activities
for _ in range(5):
    testing_filter = np.dot(testing_filter, rotated_counter_after_encoder_weights_filter)
    testing_filter = ens.neuron_type.rates(testing_filter, sim.data[ens].gain, sim.data[ens].bias)
    testing_filter = np.dot(testing_filter, activity_to_img_weights)
    testing_filter = node_func(0, testing_filter)
    _, testing_filter = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=testing_filter)

# Decode the final activities to an image
testing_filter = np.dot(testing_filter, activity_to_img_weights)

plt.subplot(133)
pylab.imshow(np.reshape(testing_filter, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

plt.show()

In [ ]:
# Sanity check: encode a training image with `ens` and decode it straight back
# (no rotation), then show the input next to its reconstruction.
_, testing_act = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=X_train[0])
testing_rotate = np.dot(testing_act, activity_to_img_weights)

for panel, image in ((121, X_train[0]), (122, testing_rotate)):
    plt.subplot(panel)
    pylab.imshow(np.reshape(image, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

plt.show()

Just for fun


In [ ]:
# Spell "OLIVIA": O and I are decoded digits (ZERO, ONE); L, V and A are SEVEN
# rotated by repeatedly applying `rotation_weights` a different number of times.
# NOTE(review): no visible cell defines `rotation_weights` (the commented-out
# loads near the top bind rotation_clockwise_weights / rotation_counter_weights
# instead) -- define it before running this cell.
label_to_img = np.dot(label_weights, activity_to_img_weights)


def rotated_label_img(label, n_steps):
    """Decode `label` into an image after applying rotation_weights n_steps times."""
    act = np.dot(label, label_weights)
    for _ in range(n_steps):
        act = np.dot(act, rotation_weights)
    return np.dot(act, activity_to_img_weights)


letterO = np.dot(ZERO, label_to_img)
letterL = rotated_label_img(SEVEN, 30)
letterI = np.dot(ONE, label_to_img)
letterV = rotated_label_img(SEVEN, 40)
letterA = rotated_label_img(SEVEN, 10)

for panel, letter in enumerate([letterO, letterL, letterI, letterV, letterI, letterA], start=1):
    plt.subplot(1, 6, panel)
    pylab.imshow(np.reshape(letter, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

plt.show()