In [1]:
%matplotlib inline
from pylab import *
Let us implement a Hopfield network using images from the MNIST dataset as patterns.
In [2]:
#### Download the dataset
# Fetch the python-mnist download script from the internet
# (stdout and stderr are discarded to keep the cell output clean)
! wget https://raw.githubusercontent.com/sorki/python-mnist/master/get_data.sh > /dev/null 2>&1
# Run it to download all files in a local dir named 'data'
! bash get_data.sh >/dev/null 2>&1
# We do not need the script anymore, remove it
# (the trailing '*' also matches any renamed duplicates of the script)
! rm get_data.sh* > /dev/null 2>&1
# Initialize the dataset variables
# NOTE(review): presumably this is what defines mndata, img_side, to_mat and
# plot_img used by the following cells -- confirm against utils.py
%run utils
We now fill an array with all the patterns. We only need a few samples, which we take from the training set. We take samples 2 and 5, representing respectively a '4' and a '2'.
In [3]:
# Indices of the two training samples used as patterns
picked = [2, 5]
# Labels of the chosen samples
labels = array(mndata.train_labels)[picked]
# Pixel rows of the chosen samples, rescaled to [0, 1], shifted
# around zero and reduced to their sign (bipolar -1/+1 input)
patterns = sign(array(mndata.train_images)[picked] / 255.0 - 0.5)
# Number of stored patterns (two in our case)
n_patterns = len(patterns)
# Number of units of the network (one per pixel)
n = img_side ** 2
Let us visualize our two patterns:
In [4]:
# One figure holding the stored patterns side by side
fig = figure(figsize=(8, 4))
for position, pattern in enumerate(patterns, 1):
    # plot positions are 1-based
    plot_img(to_mat(pattern), fig, position, windows=2)
In [5]:
# Hebbian weights: the sum over patterns of the outer products
# pattern * pattern^T is the matrix product patterns^T * patterns;
# it is then averaged over the number of patterns
W = dot(patterns.T, patterns) / float(n_patterns)
# Exclude the autoconnections (no unit is connected to itself)
fill_diagonal(W, 0.0)
In [6]:
# Number of timesteps
stime = 1000
# Number of snapshots to store while the recall goes on
samples = 100
# Store data once every `sample_interval` timesteps
# (floor division: the value is used to compute array indices)
sample_interval = stime // samples

# History of the network states, one column per snapshot.
# It is filled in during the simulation and plotted at the end
store_images = zeros([n_patterns, n, samples])
# History of the network energy, one value per snapshot.
# It is filled in during the simulation and plotted at the end
store_energy = zeros([n_patterns, samples])

# We run one simulation per stored pattern, each one starting
# from a corrupted version of that pattern
for target_index in range(n_patterns):

    # Copy the original pattern
    target = patterns[target_index]
    x = target.copy()

    # Then overwrite the second half of the image with random
    # binary values. The length is computed from the slice itself
    # so that it also matches when n is odd
    half = n // 2
    x[half:] = sign(randn(n - half))

    # During the iterations we need to pick one unit at random.
    # Thus we prepare a random sequence of indices:
    # we get the sequence of indices of the network units
    x_indices = arange(n)
    # and we shuffle it
    shuffle(x_indices)

    # The iterations
    for t in range(stime):

        # Get the current index browsing the random sequence
        current_x = x_indices[t % n]

        # Asynchronous update of a single unit.
        # NOTE(review): sign() yields 0 when the input to the unit is
        # exactly zero, leaving it neither -1 nor +1 -- confirm this
        # is acceptable
        x[current_x] = sign(dot(W[current_x, :], x))

        # Store a snapshot only once per sampling interval. The test
        # must be on the current timestep t: testing stime here is
        # always true and recomputes the energy at every timestep
        if t % sample_interval == 0:

            sample = t // sample_interval

            # Energy of the current state of the network
            store_energy[target_index, sample] = -0.5*dot(x, dot(W, x))
            # Current activation of all the units
            store_images[target_index, :, sample] = x
Here you can see two animations showing the network when it is initially activated with one of the two patterns. The initial activation is corrupted with a lot of noise, so that the bottom half of the figure is completely obscured. The network moves from this initial activation to the correct attractor state (the original uncorrupted figure). During this process the energy of the network decreases until it reaches a steady state.
We use the matplotlib.animation package for animations and the gridspec class to customize the layout of subplots.
In [7]:
# The matplotlib object to do animations
from matplotlib import animation
# This grid allows to layout subplots in a more
# flexible way
import matplotlib.gridspec as gridspec
To plot the two animations we need a function to initialize a figure with three plots: the first showing the target digit, the second showing the current activity of the network and the third showing the energy of the network.
In [8]:
def init_figure(fig):
    """Create the three animation panels on `fig` and return their artists.

    The panels are initialized with the data of the first pattern, read
    from the notebook-level variables `patterns`, `store_images` and
    `store_energy`.

    All titles and axis settings are applied through the Axes objects
    created here, so the function acts on `fig` even when it is not the
    current pyplot figure.

    fig :       the matplotlib figure to draw on

    returns :   tuple   (im_target, im_activation, im_energy), the
                        handlers of the target image, of the network
                        state image and of the energy line
    """

    # Init the grid used to lay out the subplots
    gs = gridspec.GridSpec(6, 20)

    #-------------------------------------------------
    # Plot 1 - plot the target digit

    # Create subplot
    ax1 = fig.add_subplot(gs[:4, :4])
    ax1.set_title("target")

    # Create the imshow and save the handler
    im_target = ax1.imshow(to_mat(patterns[0]),
                           interpolation='none',
                           aspect='auto',
                           cmap=cm.binary)
    ax1.axis('off')

    #-------------------------------------------------
    # Plot 2 - plot the current state of the network

    # Create subplot
    ax2 = fig.add_subplot(gs[:4, 6:10])
    ax2.set_title("recalling")

    # Create the imshow and save the handler
    im_activation = ax2.imshow(to_mat(store_images[0, :, 0]),
                               interpolation='none',
                               aspect='auto',
                               cmap=cm.binary)
    ax2.axis('off')

    #-------------------------------------------------
    # Plot 3 - plot the current history of energy

    # Create subplot
    ax3 = fig.add_subplot(gs[:4, 12:])
    ax3.set_title("Energy")

    # Create the line plot and save the handler
    im_energy, = ax3.plot(store_energy[0, ])

    # Only bottom-left axes - no ticks
    ax3.spines['top'].set_visible(False)
    ax3.spines['right'].set_visible(False)
    ax3.set_xticks([])
    ax3.set_yticks([])

    # Return plot handlers
    return im_target, im_activation, im_energy
We also need another function that updates the figure at each animation timestep with a new sample.
In [9]:
def update(data):
    """Refresh the three plots for one frame of the animation.

    data :      tuple   The arguments for a single frame: the three
                        plot handlers (target image, network state
                        image, energy line) followed by the data to
                        show in each of them

    returns :   tuple   The handlers of the three plots
    """

    # Split the frame arguments into plot handlers and their new data
    target_img, state_img, energy_line, target, state, energy = data

    # The two images take the new data reshaped by to_mat
    target_img.set_array(to_mat(target))
    state_img.set_array(to_mat(state))

    # The line is drawn against the index of each energy sample
    energy_line.set_data(arange(len(energy)), energy)

    return target_img, state_img, energy_line
Finally we use the FuncAnimation class. We first build a data list where each row is a tuple containing the plot handlers and the data needed for the plot updates.
In [10]:
# Render one animated gif per stored pattern
for target_index in range(n_patterns):

    # Init the figure
    fig = figure(figsize=(8, 3.5))
    im_target, im_activation, im_energy = init_figure(fig)

    # The energy line is created from the data of the first pattern
    # and set_data() does not rescale the axes. We fix the limits on
    # the energy history of the current pattern so that the whole
    # curve stays visible
    energy_axes = im_energy.axes
    energy_axes.set_xlim(0, samples)
    energy_axes.set_ylim(store_energy[target_index].min(),
                         store_energy[target_index].max())

    # Build the sequence of update arguments.
    # each row of the list contains:
    # 1 the target plot handler
    # 2 the activation plot handler
    # 3 the energy plot handler
    # 4 the target update data
    # 5 the activation update data
    # 6 the energy update data (history up to and including the
    #   current sample, so that it matches the displayed state)
    data = [(
        im_target,
        im_activation,
        im_energy,
        patterns[target_index],
        store_images[target_index, :, t],
        store_energy[target_index, :t + 1])
        for t in range(samples)]

    # Create and render the animation
    anim = animation.FuncAnimation(fig, func=update, frames=data)

    # Save it to file, named after the label of the target digit
    anim.save("mnist-hopfield_{:d}.gif".format(labels[target_index]),
              fps=10, writer='imagemagick')
The next cell is just for styling.
In [11]:
from IPython.core.display import HTML
def css_styling():
    """Read the notebook stylesheet and wrap it for display.

    returns :   HTML    The content of "../style/ipybn.css" wrapped in
                        an IPython HTML object
    """
    # The with-statement closes the file as soon as it has been read
    with open("../style/ipybn.css", "r") as css_file:
        styles = css_file.read()
    return HTML(styles)
css_styling()
Out[11]: