Using the trained weights in an ensemble of neurons

  • On the function points branch of nengo
  • On the vision branch of nengo_extras

In [1]:
import nengo
import numpy as np
import cPickle
from nengo_extras.data import load_mnist
from nengo_extras.vision import Gabor, Mask
from matplotlib import pylab
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy import linalg

Load the MNIST database


In [2]:
# --- load the data
# Image height and width in pixels.
img_rows = img_cols = 28

train_set, test_set = load_mnist()
X_train, y_train = train_set
X_test, y_test = test_set

# Rescale both image sets with 2*x - 1 (normalize to -1 to 1; this assumes
# load_mnist returns pixel values in [0, 1] -- TODO confirm).
X_train, X_test = (2 * images - 1 for images in (X_train, X_test))

Each digit is represented by a one-hot vector: the index of the 1 is the digit's value.


In [3]:
# 10x10 integer identity matrix: row k is the one-hot vector for digit k.
temp = np.eye(10, dtype=int)

# Unpacking a 2-D array yields its rows in order, so each name is bound to
# the matching row of `temp`.
(ZERO, ONE, TWO, THREE, FOUR,
 FIVE, SIX, SEVEN, EIGHT, NINE) = temp

# One-hot vectors ordered by digit value.
labels = [ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE]

# Side length of the square images handled by the rest of the notebook.
dim = 28

Load the saved weight matrices that were created by training the model


In [4]:
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    The file is opened in binary mode (pickle data is binary; text mode can
    corrupt it on some platforms) and is closed as soon as loading finishes.

    NOTE: unpickling can execute arbitrary code. Only load weight files that
    you generated yourself with the training notebook.
    """
    with open(path, "rb") as handle:
        return cPickle.load(handle)


label_weights = _load_pickle("label_weights5000.p")
activity_to_img_weights = _load_pickle("activity_to_img_weights5000.p")
rotated_clockwise_after_encoder_weights = _load_pickle("rotated_after_encoder_weights_clockwise5000.p")
rotated_counter_after_encoder_weights = _load_pickle("rotated_after_encoder_weights5000.p")

# Optional manipulations. Uncomment the matching line here before enabling
# the corresponding add_manipulation(...) call in the network cell.
#scale_up_after_encoder_weights = _load_pickle("scale_up_after_encoder_weights1000.p")
#scale_down_after_encoder_weights = _load_pickle("scale_down_after_encoder_weights1000.p")
#translate_up_after_encoder_weights = _load_pickle("translate_up_after_encoder_weights1000.p")
#translate_down_after_encoder_weights = _load_pickle("translate_down_after_encoder_weights1000.p")
#translate_left_after_encoder_weights = _load_pickle("translate_left_after_encoder_weights1000.p")
#translate_right_after_encoder_weights = _load_pickle("translate_right_after_encoder_weights1000.p")

#identity_after_encoder_weights = _load_pickle("identity_after_encoder_weights1000.p")

Functions to perform the inhibition of each ensemble


In [24]:
#A value of zero gives no inhibition

def inhibit_rotate_clockwise(t):
    """Inhibition signal for the clockwise-rotation ensemble.

    Returns ``dim**2`` while ``t < 1`` and 0 (no inhibition) from then on.
    """
    return dim**2 if t < 1 else 0
    
def inhibit_rotate_counter(t):
    """Inhibition signal for the counter-clockwise-rotation ensemble.

    Returns 0 (no inhibition) while ``t < 1`` and ``dim**2`` from then on.
    """
    return 0 if t < 1 else dim**2
    
def inhibit_identity(t):
    """Inhibition signal for the identity ensemble.

    Returns ``dim**2`` at every time ``t``, i.e. the ensemble is always
    inhibited.
    """
    return dim**2
    
def inhibit_scale_up(t):
    """Inhibition signal for the scale-up ensemble: ``dim**2`` at every ``t``."""
    return dim * dim
def inhibit_scale_down(t):
    """Inhibition signal for the scale-down ensemble: ``dim**2`` at every ``t``."""
    return dim * dim

def inhibit_translate_up(t):
    """Inhibition signal for the translate-up ensemble: ``dim**2`` at every ``t``."""
    return dim * dim
def inhibit_translate_down(t):
    """Inhibition signal for the translate-down ensemble: ``dim**2`` at every ``t``."""
    return dim * dim
def inhibit_translate_left(t):
    """Inhibition signal for the translate-left ensemble: ``dim**2`` at every ``t``."""
    return dim * dim
def inhibit_translate_right(t):
    """Inhibition signal for the translate-right ensemble: ``dim**2`` at every ``t``."""
    return dim * dim

The network where the mental imagery and rotation occurs

  • The state, seed and ensemble parameters (including encoders) must all be the same for the saved weight matrices to work
  • The number of neurons (n_hid) must be the same as was used for training
  • The input must be shown for a short period of time to be able to view the rotation
  • The recurrent connection must be from the neurons because the weight matrices were trained on the neuron activities

In [25]:
def add_manipulation(main_ens, weights, inhibition_func):
    """Attach a gated manipulation ensemble to ``main_ens``.

    Creates a second ensemble with the same size, seed, encoders and
    parameters as the main one, wires it to ``main_ens`` in a neuron-to-neuron
    loop, and adds a node that can inhibit the new ensemble.

    Must be called inside a ``with <network>:`` block. Reads the notebook
    globals ``n_hid``, ``dim``, ``encoders`` and ``ens_params``, so those must
    be defined before the call.

    Parameters
    ----------
    main_ens : nengo.Ensemble
        Ensemble holding the image representation.
    weights : array
        Learned weight matrix; its transpose is used as the transform on both
        the forward and the return connection.
    inhibition_func : callable
        Function of time ``t`` used as the inhibition node's output. Its value
        is multiplied by -1 and fed to every neuron of the manipulation
        ensemble, so 0 means no inhibition.

    Returns
    -------
    (nengo.Ensemble, nengo.Node)
        The manipulation ensemble and its inhibition node, so callers can
        probe them. Existing callers that ignore the result are unaffected.
    """
    # Create ensemble for manipulation
    ens_manipulation = nengo.Ensemble(n_hid, dim**2, seed=3, encoders=encoders, **ens_params)
    # Create node for inhibition
    inhib_manipulation = nengo.Node(inhibition_func)
    # Connect the main ensemble to the manipulation ensemble and back,
    # applying the learned transformation on each leg
    nengo.Connection(main_ens.neurons, ens_manipulation.neurons, transform=weights.T, synapse=0.1)
    nengo.Connection(ens_manipulation.neurons, main_ens.neurons, transform=weights.T, synapse=0.1)
    # Connect inhibition: one row of -1 per neuron
    nengo.Connection(inhib_manipulation, ens_manipulation.neurons, transform=[[-1]] * n_hid)

    return ens_manipulation, inhib_manipulation

In [26]:
# Random state used only for generating the encoders; it must match training.
rng = np.random.RandomState(9)
# Number of neurons per ensemble; must equal the value used for training.
n_hid = 1000
model = nengo.Network(seed=3)
with model:
    # Stimulus only shows for a brief period of time (t < 0.1), then goes to 0.
    # Alternative stimulus: nengo.processes.PresentInput(labels, 1)
    stim = nengo.Node(lambda t: ONE if t < 0.1 else 0)

    # Parameters shared by the main ensemble and every manipulation ensemble.
    ens_params = dict(
        eval_points=X_train,
        neuron_type=nengo.LIF(),
        intercepts=nengo.dists.Choice([-0.5]),
        max_rates=nengo.dists.Choice([100]),
    )

    # Linear filter used for edge detection as encoders, more plausible for
    # the human visual system
    encoders = Gabor().generate(n_hid, (11, 11), rng=rng)
    encoders = Mask((28, 28)).populate(encoders, rng=rng, flatten=True)

    # Ensemble that represents the image with different transformations applied to it
    ens = nengo.Ensemble(n_hid, dim**2, seed=3, encoders=encoders, **ens_params)

    # Connect stimulus to ensemble, transform using learned weight matrices
    nengo.Connection(stim, ens, transform=np.dot(label_weights, activity_to_img_weights).T)

    # Recurrent connection on the neurons of the ensemble to perform the rotation
    #nengo.Connection(ens.neurons, ens.neurons, transform = rotated_counter_after_encoder_weights.T, synapse=0.1)

    # Manipulation branches. Only enable a branch whose weight matrix is
    # actually loaded in the weight-loading cell.
    #add_manipulation(ens,rotated_clockwise_after_encoder_weights, inhibit_rotate_clockwise)
    add_manipulation(ens, rotated_counter_after_encoder_weights, inhibit_rotate_counter)
    # Disabled: scale_up_after_encoder_weights is not loaded (its load line is
    # commented out), so this call raised NameError on a fresh kernel. The
    # branch was permanently inhibited by inhibit_scale_up in any case.
    #add_manipulation(ens,scale_up_after_encoder_weights, inhibit_scale_up)
    #add_manipulation(ens,scale_down_after_encoder_weights, inhibit_scale_down)
    #add_manipulation(ens,translate_up_after_encoder_weights, inhibit_translate_up)
    #add_manipulation(ens,translate_down_after_encoder_weights, inhibit_translate_down)
    #add_manipulation(ens,translate_left_after_encoder_weights, inhibit_translate_left)
    #add_manipulation(ens,translate_right_after_encoder_weights, inhibit_translate_right)

    # Collect output, use synapse for smoothing
    probe = nengo.Probe(ens.neurons, synapse=0.1)

In [27]:
# Build a simulator for the network constructed in the previous cell.
sim = nengo.Simulator(model)

In [28]:
# Run the simulation for 5 time units (seconds of simulated time in Nengo).
sim.run(5)


Simulation finished in 0:01:05.                                                 

The following is not part of the brain model, it is used to view the output for the ensemble

Since it's probing the neurons themselves, the output must be transformed from neuron activity to visual image


In [29]:
# Animation for Probe output
fig = plt.figure()

# Decode every probed activity vector into image space with one matrix
# product; row i is the flattened image for simulation step i.
output_acts = np.dot(sim.data[probe], activity_to_img_weights)


def frame_image(i):
    """Return decoded step ``i`` as a (dim, dim) image array."""
    return np.reshape(output_acts[i], (dim, dim), order='F').T


# Create the image artist once and update its data per frame, rather than
# stacking a new image onto the axes for every frame.
im = plt.imshow(frame_image(0), cmap=plt.get_cmap('Greys_r'), animated=True)


def updatefig(i):
    """Show frame ``i``; returns the artists that need redrawing (for blit)."""
    im.set_data(frame_image(i))
    # Rescale the colour limits to this frame, matching what a fresh imshow
    # call would do.
    im.autoscale()
    return im,


# Pass `frames` so the index stays within output_acts; without it the
# animation counts upward forever and eventually raises IndexError.
ani = animation.FuncAnimation(fig, updatefig, frames=len(output_acts),
                              interval=100, blit=True)
plt.show()

In [30]:
print(len(sim.data[probe]))

# Show the decoded image at six simulation steps, side by side in a 1x6 grid.
# Each panel is titled with its step index.
snapshot_steps = (100, 500, 1000, 1500, 2000, 2500)
for panel, step in enumerate(snapshot_steps, start=1):
    plt.subplot(1, 6, panel)
    plt.title(str(step))
    pylab.imshow(np.reshape(output_acts[step], (dim, dim), order='F').T,
                 cmap=plt.get_cmap('Greys_r'))

plt.show()


5000

Pickle the probe's output if it takes a long time to run


In [ ]:
# The filename includes the number of neurons and which digit is being rotated
filename = "mental_rotation_output_ONE_" + str(n_hid) + ".p"
# The context manager flushes and closes the file even if pickling fails.
with open(filename, "wb") as out_file:
    cPickle.dump(sim.data[probe], out_file)

Testing


In [ ]:
# Get image: decode the label ONE into image space via the learned weights.
testing = np.dot(ONE, np.dot(label_weights, activity_to_img_weights))
plt.subplot(121)
pylab.imshow(np.reshape(testing, (dim, dim), order='F').T, cmap=plt.get_cmap('Greys_r'))

# Get activity of image
_, testing_act = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=testing)

# Apply the rotation six times. Each step maps activities to rotated encoder
# outputs and then converts them back to activities through the neuron model.
# The matrix loaded from "rotated_after_encoder_weights5000.p" is named
# rotated_counter_after_encoder_weights in this notebook.
testing_rotate = testing_act
for _ in range(6):
    # Get rotated encoder outputs
    testing_rotate = np.dot(testing_rotate, rotated_counter_after_encoder_weights)
    # Get activities
    testing_rotate = ens.neuron_type.rates(testing_rotate, sim.data[ens].gain, sim.data[ens].bias)

# Decode the final activities back into an image.
testing_rotate = np.dot(testing_rotate, activity_to_img_weights)

plt.subplot(122)
pylab.imshow(np.reshape(testing_rotate, (dim, dim), order='F').T, cmap=plt.get_cmap('Greys_r'))

plt.show()

In [ ]:
# Get activity of image: ensemble response to the first training image.
testing_act = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=X_train[0])[1]

# Decode those activities back into image space.
testing_rotate = np.dot(testing_act, activity_to_img_weights)

# Left panel: the original image. Right panel: its reconstruction.
for position, flat_image in enumerate((X_train[0], testing_rotate), start=1):
    plt.subplot(1, 2, position)
    pylab.imshow(np.reshape(flat_image, (dim, dim), order='F').T,
                 cmap=plt.get_cmap('Greys_r'))

plt.show()

Just for fun


In [ ]:
def _show_letter(position, flat_image):
    """Draw a flattened image in slot ``position`` of a 1x6 subplot grid."""
    plt.subplot(1, 6, position)
    pylab.imshow(np.reshape(flat_image, (dim, dim), order='F').T,
                 cmap=plt.get_cmap('Greys_r'))


def _rotated_seven(n_steps):
    """Decode the SEVEN label after ``n_steps`` products with rotation_weights."""
    # NOTE(review): rotation_weights is not defined in any cell visible in
    # this notebook, so this raises NameError on a fresh kernel -- confirm
    # which weight file it should be loaded from.
    acts = np.dot(SEVEN, label_weights)
    for _ in range(n_steps):
        acts = np.dot(acts, rotation_weights)
    return np.dot(acts, activity_to_img_weights)


# Slots 1..6 read O, L, I, V, I, A from left to right.
letterO = np.dot(ZERO, np.dot(label_weights, activity_to_img_weights))
_show_letter(1, letterO)

letterL = _rotated_seven(30)
_show_letter(2, letterL)

# The same decoded ONE is drawn in two slots.
letterI = np.dot(ONE, np.dot(label_weights, activity_to_img_weights))
_show_letter(3, letterI)
_show_letter(5, letterI)

letterV = _rotated_seven(40)
_show_letter(4, letterV)

letterA = _rotated_seven(10)
_show_letter(6, letterA)

plt.show()