In [1]:
import nengo
import numpy as np
import cPickle
from nengo_extras.data import load_mnist
from nengo_extras.vision import Gabor, Mask
from matplotlib import pylab
import matplotlib.pyplot as plt
import matplotlib.animation as animation
Load the MNIST database
In [2]:
# --- load the data
# Side lengths (pixels) of an MNIST image.
img_rows = img_cols = 28
(X_train, y_train), (X_test, y_test) = load_mnist()
# Map pixel intensities onto [-1, 1] (assumes load_mnist yields values in [0, 1]).
X_train, X_test = 2 * X_train - 1, 2 * X_test - 1
Each digit label is represented by a one-hot vector: the index of the single 1 is the digit's value.
In [3]:
# Row k of the 10x10 identity matrix is the one-hot code for digit k.
temp = np.eye(10, dtype=int)
ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE = temp
# Codes in digit order, so labels[k] is the vector for digit k.
labels = [ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE]
# Side length of a square MNIST image; images are flattened to dim**2 values.
dim = 28
Load the saved weight matrices that were created by training the model.
In [4]:
def _load_pickle(path):
    """Load and return one pickled object (here: a weight matrix) from ``path``.

    The file is always opened in binary mode: pickles are byte streams, and
    text mode ("r") can corrupt them on platforms that translate line endings.
    The ``with`` block guarantees the file handle is closed.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, "rb") as f:
        return cPickle.load(f)

# Weight matrices produced by the training notebook (the "1000" suffix
# presumably refers to the number of neurons -- confirm against n_hid).
label_weights = _load_pickle("label_weights1000.p")
activity_to_img_weights = _load_pickle("activity_to_img_weights_scale1000.p")
scale_up_after_encoder_weights = _load_pickle("scale_up_after_encoder_weights1000.p")
scale_down_after_encoder_weights = _load_pickle("scale_down_after_encoder_weights1000.p")
scale_up_weights = _load_pickle("scale_up_weights1000.p")
scale_down_weights = _load_pickle("scale_down_weights1000.p")
In [5]:
# Seeded RNG for encoder generation so the network is reproducible.
rng = np.random.RandomState(9)
# Number of neurons (presumably must match the "1000" suffix of the weight files).
n_hid = 1000
model = nengo.Network(seed=3)
with model:
    # Stimulus: the one-hot code for ZERO during the first 0.1 s, then an
    # all-zero vector of the same length, so the node output never changes
    # shape (previously a scalar 0 relied on broadcasting).
    blank_label = np.zeros_like(ZERO)
    stim = nengo.Node(lambda t: ZERO if t < 0.1 else blank_label)
    # To present every digit in turn instead, use
    # nengo.processes.PresentInput(labels, 1) as the node output.

    ens_params = dict(
        eval_points=X_train,  # evaluation points are the rescaled training images
        neuron_type=nengo.LIF(),
        intercepts=nengo.dists.Choice([-0.5]),
        max_rates=nengo.dists.Choice([100]),
    )

    # Gabor filters (oriented edge detectors, a plausible model of early
    # visual receptive fields) as encoders: n_hid 11x11 patches placed into a
    # dim x dim image and flattened to dim**2 values.
    encoders = Gabor().generate(n_hid, (11, 11), rng=rng)
    encoders = Mask((dim, dim)).populate(encoders, rng=rng, flatten=True)

    ens = nengo.Ensemble(n_hid, dim**2, seed=3, encoders=encoders, **ens_params)

    # Recurrent neuron-to-neuron connection that repeatedly applies the learned
    # scale-down transform to the ensemble's activity.
    nengo.Connection(ens.neurons, ens.neurons,
                     transform=scale_down_after_encoder_weights.T, synapse=0.1)

    # Stimulus -> ensemble: label_weights . activity_to_img_weights turns a
    # one-hot label into an image vector (presumably label -> activities -> pixels).
    nengo.Connection(stim, ens,
                     transform=np.dot(label_weights, activity_to_img_weights).T,
                     synapse=0.1)

    # Record neuron activity; the synapse low-pass filters (smooths) the spikes.
    probe = nengo.Probe(ens.neurons, synapse=0.1)
In [6]:
sim = nengo.Simulator(network=model)  # build the network into a runnable simulator
In [7]:
sim.run(time_in_seconds=5)  # simulate 5 s; the stimulus is only on for the first 0.1 s
In [10]:
# --- Animation of the decoded image over the course of the simulation
fig = plt.figure()
# Decode every probed time step in one product:
# (n_steps, n_hid) . (n_hid, dim**2) -> one flattened image per row.
output_acts = np.dot(sim.data[probe], activity_to_img_weights)


def _to_image(vec):
    """Reshape a flattened image vector into a (dim, dim) array for display."""
    return np.reshape(vec, (dim, dim), 'F').T


# A single image artist updated in place each frame, instead of stacking a
# new imshow artist on the axes for every time step.
frame_image = pylab.imshow(_to_image(output_acts[0]),
                           cmap=plt.get_cmap('Greys_r'), animated=True)


def updatefig(i):
    """Show decoded frame ``i``; return the changed artists for blitting."""
    frame_image.set_data(_to_image(output_acts[i]))
    # Rescale colour limits per frame, as a fresh imshow call would.
    frame_image.autoscale()
    return frame_image,


# frames bounds i to the recorded steps (without it FuncAnimation counts
# forever and output_acts[i] raises IndexError); interval is an int in ms.
ani = animation.FuncAnimation(fig, updatefig, frames=len(output_acts),
                              interval=1, blit=True)
plt.show()
Pickle the probe's output so that a long simulation does not have to be re-run.
In [ ]:
# The filename records the stimulus digit and the number of neurons (n_hid).
filename = "mental_scaling_output_ZERO_" + str(n_hid) + ".p"
# Context manager so the file is flushed and closed; the highest (binary)
# protocol stores numpy arrays far more compactly than the text default.
with open(filename, "wb") as f:
    cPickle.dump(sim.data[probe], f, cPickle.HIGHEST_PROTOCOL)
In [40]:
# --- Offline check: predict the scaled image using rate approximations
# Image encoded by the ZERO label (the same label -> image product that feeds
# the stimulus connection in the model).
testing = np.dot(ZERO, np.dot(label_weights, activity_to_img_weights))
plt.subplot(121)
plt.title("Input")
pylab.imshow(np.reshape(testing, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))

# Steady-state neuron activities of the ensemble for that image.
_, testing_act = nengo.utils.ensemble.tuning_curves(ens, sim, inputs=testing)

# Apply the learned scale-down transform N_SCALE_STEPS times, converting back
# to firing rates after each step.
# NOTE(review): here gain is applied to the transformed values, while the
# recurrent connection in the network targets ens.neurons directly -- confirm
# this matches how scale_down_after_encoder_weights was trained.
N_SCALE_STEPS = 3
testing_scale = testing_act
for step in range(N_SCALE_STEPS):
    testing_scale = np.dot(testing_scale, scale_down_after_encoder_weights)
    testing_scale = ens.neuron_type.rates(testing_scale, sim.data[ens].gain, sim.data[ens].bias)

# Decode the final activities back to pixel space.
testing_scale = np.dot(testing_scale, activity_to_img_weights)
plt.subplot(122)
plt.title("After %d scale-down steps" % N_SCALE_STEPS)
pylab.imshow(np.reshape(testing_scale, (dim, dim), 'F').T, cmap=plt.get_cmap('Greys_r'))
plt.show()