In [4]:
import matplotlib.pyplot as plt
#%matplotlib inline
import nengo
import numpy as np
import scipy.ndimage
import matplotlib.animation as animation
from matplotlib import pylab
from PIL import Image
import nengo.spa as spa
import cPickle

from nengo_extras.data import load_mnist
from nengo_extras.vision import Gabor, Mask

#Encode categorical integer features using a one-hot aka one-of-K scheme.
def one_hot(labels, c=None):
    """One-hot encode a 1-D array of non-negative integer class labels.

    Parameters
    ----------
    labels : 1-D array-like of non-negative ints
        Class index of each sample.
    c : int, optional
        Number of classes (columns). Defaults to ``max(labels) + 1`` so that
        every label has a column even when some classes are absent from
        ``labels`` (counting unique labels would under-size the matrix).

    Returns
    -------
    ndarray of shape (len(labels), c)
        Row k has a single 1 in column ``labels[k]``.

    Raises
    ------
    ValueError
        If any label is negative (it would silently wrap to a column from
        the end of the matrix).
    """
    labels = np.asarray(labels)
    assert labels.ndim == 1
    n = labels.shape[0]
    if labels.size and labels.min() < 0:
        raise ValueError("labels must be non-negative")
    if c is None:
        c = int(labels.max()) + 1 if labels.size else 0
    y = np.zeros((n, c))
    y[np.arange(n), labels] = 1
    return y


# Seeded random generator, consumed below when generating the Gabor encoders.
rng = np.random.RandomState(seed=9)

In [5]:
# --- load the data
# MNIST images are 28x28 pixels (stored flattened).
img_rows, img_cols = 28, 28

(X_train, y_train), (X_test, y_test) = load_mnist()

# Map pixel intensities onto [-1, 1] (assumes load_mnist returns values in [0, 1]).
X_train, X_test = 2 * X_train - 1, 2 * X_test - 1

# One-hot label matrices over the 10 digit classes.
train_targets, test_targets = one_hot(y_train, 10), one_hot(y_test, 10)

In [6]:
# --- set up network parameters
# The ensemble auto-encodes images: both the represented input and the
# decoded output are flattened images, so both sizes equal the pixel count.
n_vis = X_train.shape[1]
n_out = X_train.shape[1]

# Number of neurons; the activity vector over them is used as the semantic
# pointer later on. More neurons should give more accurate reconstructions.
n_hid = 1000

# Encoding/decoding is evaluated on the training images themselves.
ens_params = dict(
    eval_points=X_train,
    # Rate neurons: activities here are only ever computed with
    # tuning_curves (no spiking simulation is run), so a rate model suffices.
    neuron_type=nengo.LIFRate(),
    intercepts=nengo.dists.Choice([-0.5]),
    max_rates=nengo.dists.Choice([100]),
)

# L2-regularized least-squares decoder solvers: `solver` for the image
# reconstruction, `solver2` for the label readout (reg=0.0001 was also tried
# for `solver`).
solver = nengo.solvers.LstsqL2(reg=0.01)
solver2 = nengo.solvers.LstsqL2(reg=0.01)

# Network: one ensemble `a` representing a flattened image, with two decoded
# readouts whose decoders are solved over the training images --
# the reconstructed image (node `v`) and the one-hot digit label (node `v2`).
with nengo.Network(seed=3) as model:
    a = nengo.Ensemble(n_hid, n_vis, seed=3, **ens_params)

    # Readout 1: decode the image itself (auto-encoder).
    v = nengo.Node(size_in=n_out)
    conn = nengo.Connection(
        a, v,
        synapse=None,
        solver=solver,
        eval_points=X_train,
        function=X_train,  # target for each eval point is that same image
    )

    # Readout 2: decode the one-hot label of each training image.
    v2 = nengo.Node(size_in=train_targets.shape[1])
    conn2 = nengo.Connection(
        a, v2,
        synapse=None,
        solver=solver2,
        eval_points=X_train,
        function=train_targets,
    )
    
    

def get_outs(sim, images):
    """Label-readout values decoded from the ensemble for a batch of images.

    Computes the ensemble's rate activities for `images` (via get_activities)
    and applies the decoders of the label connection `conn2`.

    Returns an array with one row per image and one column per dimension of
    the label node `v2`.
    """
    acts = get_activities(sim, images)
    # The decoder matrix is (output dims x neurons), so it is transposed to
    # map (images x neurons) activities to (images x output dims).
    return np.dot(acts, sim.data[conn2].weights.T)

def get_error(sim, images, labels):
    """Boolean array, True where the predicted digit differs from `labels`.

    Predictions come from get_labels (argmax of the decoded label readout).
    """
    return get_labels(sim, images) != labels

def get_labels(sim, images):
    """Predicted digit per image: index of the largest decoded label output."""
    readout = get_outs(sim, images)
    return readout.argmax(axis=1)

def get_activities(sim, images):
    """Rate activities of ensemble `a` for each image (the semantic pointers).

    Uses nengo's tuning_curves with the images as inputs and returns only the
    activities (the second element of its result), one row per image.
    """
    tuning = nengo.utils.ensemble.tuning_curves(a, sim, inputs=images)
    return tuning[1]

def get_encoder_outputs(sim, images):
    """Project images onto ensemble `a`'s encoders (the input side of the neurons).

    The encoders are stored one per row (neurons x dimensions; see the Gabor
    encoders assigned to `a`), so they are transposed to turn
    (images x dimensions) input into an (images x neurons) result.
    NOTE(review): presumably these are the unscaled encoders, i.e. no gain or
    bias is applied -- confirm against the nengo version in use.
    """
    encoder_matrix = sim.data[a].encoders
    return np.dot(images, encoder_matrix.T)

In [7]:
'''
#Images to train for rotation of 90 deg
orig_imgs = X_train[:10000].copy()

rotated_imgs =X_train[:10000].copy()
for img in rotated_imgs:
    img[:] = scipy.ndimage.interpolation.rotate(np.reshape(img,(28,28)),90,reshape=False).ravel()


test_imgs = X_test[:1000].copy()
'''  

# Rotation (degrees) between each training image and its target.
ROTATION_STEP_DEG = 6

# Dedicated seeded RNG for the random starting orientations, so the data set
# is reproducible. It is separate from `rng`, so the Gabor encoders generated
# from `rng` later are unaffected.
rotation_rng = np.random.RandomState(10)


def rotate_flat_img(img, angle):
    """Rotate a flattened image by `angle` degrees and return it flattened.

    The output keeps the original frame size (reshape=False); pixels rotated in
    from outside the frame copy the nearest edge pixel (mode="nearest").
    """
    square = np.reshape(img, (img_rows, img_cols))
    return scipy.ndimage.rotate(square, angle, reshape=False, mode="nearest").ravel()


# Training images, each starting at a random orientation. (If X_train has fewer
# than 100000 rows, the slice simply takes all of them.)
orig_imgs = X_train[:100000].copy()
for img in orig_imgs:
    img[:] = rotate_flat_img(img, rotation_rng.randint(360))

# Targets: the same images rotated a further ROTATION_STEP_DEG degrees.
rotated_imgs = orig_imgs.copy()
for img in rotated_imgs:
    img[:] = rotate_flat_img(img, ROTATION_STEP_DEG)

# Held-out test images, all at random orientations.
test_imgs = X_test[:1000].copy()
for img in test_imgs:
    img[:] = rotate_flat_img(img, rotation_rng.randint(360))

# Check that rotation is done correctly
fig, (ax_orig, ax_rot) = plt.subplots(1, 2)
ax_orig.imshow(orig_imgs[5].reshape(img_rows, img_cols), cmap='gray')
ax_orig.set_title('original (random angle)')
ax_rot.imshow(rotated_imgs[5].reshape(img_rows, img_cols), cmap='gray')
ax_rot.set_title('rotated +%d deg' % ROTATION_STEP_DEG)
plt.show()


Out[7]:
<matplotlib.image.AxesImage at 0xb98b6d8>

In [8]:
# Gabor filters (oriented edge detectors, considered more plausible for the
# human visual system) used as encoders: n_hid 11x11 filters, embedded in the
# 28x28 image frame by Mask.populate and flattened.
encoders = Gabor().generate(n_hid, (11, 11), rng=rng)
encoders = Mask((28, 28)).populate(encoders, rng=rng, flatten=True)
# Replace the ensemble's encoders; this takes effect when the Simulator below
# builds the model.
a.encoders = encoders

# Check the encoders were correctly made
plt.imshow(encoders[0].reshape(28, 28), vmin=encoders[0].min(), vmax=encoders[0].max(), cmap='gray')
plt.title('encoder 0')
plt.show()


with nengo.Simulator(model) as sim:
    # Neuron activities ("semantic pointers") of each image set. Later cells
    # read all of these names, so they are computed here instead of relying
    # on values left over in the kernel.
    # NOTE(review): with the full training set, each activity matrix is large
    # (images x n_hid floats) -- watch memory.
    orig_acts = get_activities(sim, orig_imgs)
    rotated_acts = get_activities(sim, rotated_imgs)
    test_acts = get_activities(sim, test_imgs)

    # Activities and decoded label outputs for the unrotated test set.
    X_test_acts = get_activities(sim, X_test)
    labels_out = get_outs(sim, X_test)

    # Encoder projections of the rotated training images.
    rotated_after_encoders = get_encoder_outputs(sim, rotated_imgs)

    # Least-squares solvers used to learn linear maps between representations.
    solver_tranform = nengo.solvers.LstsqL2(reg=1e-8)
    solver_word = nengo.solvers.LstsqL2(reg=1e-8)
    solver_rotate_encoder = nengo.solvers.LstsqL2(reg=1e-8)

    # Each solver call returns (weights, info); only the weight matrix is kept.
    # activity(original image) -> activity(image rotated by one step)
    weights, _ = solver_tranform(orig_acts, rotated_acts)
    # decoded label vector -> neuron activity (used to "imagine" a digit)
    label_weights, _ = solver_word(labels_out, X_test_acts)
    # activity(original image) -> encoder projection of the rotated image
    rotated_after_encoder_weights, _ = solver_rotate_encoder(orig_acts, rotated_after_encoders)


#cPickle.dump(rotated_after_encoder_weights, open( "rotated_after_encoder_weights.p", "wb" ) )

In [213]:
# Sanity check: one decoded label vector per test image.
print(np.shape(labels_out))


(10000L, 10L)

In [26]:
#test_targets[i]
'''
ZERO = test_targets[3]
ONE = test_targets[2]
TWO = test_targets[1]
THREE = test_targets[30]
FOUR = test_targets[19]
FIVE = test_targets[8]
SIX = test_targets[11]
SEVEN = test_targets[17]
EIGHT = test_targets[61]
NINE = test_targets[7]
'''

# One-hot label vector for each digit: the rows of the 10x10 identity matrix.
digit_onehots = np.eye(10, dtype=int)
ZERO, ONE, TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE = digit_onehots

#Change this to imagine different digits
imagine = EIGHT

# Label -> neuron activity (through the learned label_weights)
test_activity = np.dot(imagine, label_weights)
# Neuron activity -> decoded image
test_output_img = np.dot(test_activity, sim.data[conn].weights.T)

plt.imshow(test_output_img.reshape(28, 28), cmap='gray')
plt.title('imagined digit %d' % np.argmax(imagine))
plt.show()

In [6]:
#import cPickle
#cPickle.dump(label_weights, open( "label_weights.p", "wb" ) )
#cPickle.dump(sim.data[conn].weights.T, open( "activity_to_img_weights.p", "wb" ) )
#cPickle.dump(weights, open( "rotation_weights.p", "wb" ) )
#cPickle.dump(rotated_after_encoder_weights, open( "rotated_after_encoder_weights.p", "wb" ) )

In [14]:
# Pick one random test image; the global `i` is also read by a later cell.
i = np.random.randint(len(test_acts))

# Decoding matrix: neuron activity -> flattened 28x28 image.
img_decoders = sim.data[conn].weights.T

#Activity of the rotated semantic pointer, dot product of activity(semantic) and weight matrix
test_output_act = np.dot(test_acts[i], weights)

#Image decoded with no rotation
test_output_img_unrot = np.dot(test_acts[i], img_decoders)
#Image decoded after rotation
test_output_img = np.dot(test_output_act, img_decoders)
# Image rotated directly, with no neurons. 6 degrees matches the rotation step
# used to build the training targets; scipy.ndimage.rotate is the supported
# name of the deprecated scipy.ndimage.interpolation.rotate.
output_img_rot = scipy.ndimage.rotate(np.reshape(test_imgs[i], (28, 28)), 6, reshape=False, mode="nearest").ravel()

panels = [
    (test_imgs[i], 'input'),
    (test_output_img_unrot, 'decoded'),
    (output_img_rot, 'rotated (no neurons)'),
    (test_output_img, 'decoded after rotation'),
]
fig, axes = plt.subplots(1, len(panels))
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img.reshape(28, 28), cmap='gray')
    ax.set_title(panel_title)
plt.show()

In [42]:
imagine = THREE
frames = 60

# Sequence of neuron activities: frame 0 is the imagined digit
# (label -> activity); each later frame applies the learned rotation weights
# to the previous frame. The loop variable is deliberately not `i`, so it does
# not clobber the global test-image index `i` that a later cell reads.
rot_seq = [np.dot(imagine, label_weights)]
for _frame in range(1, frames):
    rot_seq.append(np.dot(rot_seq[-1], weights))

In [45]:
#Animation of rotation
fig, ax = plt.subplots()
# Decoding matrix: neuron activity -> flattened 28x28 image.
anim_decoders = sim.data[conn].weights.T


def decode_frame(k):
    """Decode frame k of rot_seq into a 28x28 image (row-major layout)."""
    return np.dot(rot_seq[k], anim_decoders).reshape(28, 28)


# Create the image artist once and update it per frame, rather than stacking a
# new imshow artist on the axes every frame.
im = ax.imshow(decode_frame(0), cmap=plt.get_cmap('Greys_r'), animated=True)


def updatefig(i):
    """Show frame i of rot_seq; returns the changed artists for blitting."""
    im.set_data(decode_frame(i))
    im.autoscale()  # rescale colour limits to this frame's values
    return im,


# frames=len(rot_seq): without a frame count FuncAnimation counts up forever,
# and rot_seq[i] raises IndexError once the sequence is exhausted.
ani = animation.FuncAnimation(fig, updatefig, frames=len(rot_seq), interval=100, blit=True)
plt.show()

In [15]:
frames = 60

# Sequence of neuron activities for test image `i`: frame 0 is its activity
# after one learned rotation; each later frame rotates the previous one again.
# The loop variable is deliberately not `i`, so the global index is not
# overwritten.
rot_seq = [np.dot(test_acts[i], weights)]
for _frame in range(1, frames):
    rot_seq.append(np.dot(rot_seq[-1], weights))

In [28]:
fig, ax = plt.subplots()
# Decoding matrix: neuron activity -> flattened 28x28 image.
anim_decoders = sim.data[conn].weights.T


def decode_frame(k):
    """Decode frame k of rot_seq into a 28x28 image (row-major layout)."""
    return np.dot(rot_seq[k], anim_decoders).reshape(28, 28)


# Create the image artist once and update it per frame, rather than stacking a
# new imshow artist on the axes every frame.
im = ax.imshow(decode_frame(0), cmap=plt.get_cmap('Greys_r'), animated=True)


def updatefig(i):
    """Show frame i of rot_seq; returns the changed artists for blitting."""
    im.set_data(decode_frame(i))
    im.autoscale()  # rescale colour limits to this frame's values
    return im,


# frames=len(rot_seq): without a frame count FuncAnimation counts up forever,
# and rot_seq[i] raises IndexError once the sequence is exhausted.
ani = animation.FuncAnimation(fig, updatefig, frames=len(rot_seq), interval=100, blit=True)
plt.show()

--- Ignore: the cells below are scratch work and not part of the analysis above ---


In [59]:
# Load an exported image for inspection. It is bound to `export_img`, not `a`,
# so the ensemble `a` used by get_activities/get_outs is not overwritten.
export_img = np.array(Image.open("export.png"))


pylab.imshow(export_img, cmap="gray")
plt.show()

In [ ]:
# Hand-picked training images, one per row: row k is X_train[digit_example_idx[k]].
# Presumably these are examples of digits 0..9 in order -- verify against
# y_train[digit_example_idx]. (Previously every assignment targeted row 0, so
# only row 0 changed and rows 1-9 stayed X_train[1..9].)
digit_example_idx = [1, 8, 25, 12, 26, 47, 18, 29, 46, 22]
ordered_imgs = X_train[digit_example_idx]  # fancy indexing returns a copy

# Ten random 10-dimensional semantic pointers, one per slot (presumably one
# name per row of ordered_imgs -- confirm). A 1-D object array stores each
# SemanticPointer object as-is; the former (10, 10) array instead received a
# single object per row of ten slots.
names = np.empty(10, dtype=object)

for k in range(10):
    names[k] = spa.SemanticPointer(10)