In this tutorial, we train a simple RBM on the MNIST dataset and visualize its learned filters.
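Before we dive into the code, here is a brief reminder of the model we are training (a sketch in standard RBM notation; xRBM's internals may differ in details such as bias naming). An RBM with binary visible units $\mathbf{v}$ and hidden units $\mathbf{h}$ assigns each configuration the energy

$$E(\mathbf{v}, \mathbf{h}) = -\mathbf{b}^\top \mathbf{v} - \mathbf{c}^\top \mathbf{h} - \mathbf{v}^\top W \mathbf{h},$$

with conditionals $p(h_j = 1 \mid \mathbf{v}) = \sigma(c_j + \sum_i W_{ij} v_i)$ and $p(v_i = 1 \mid \mathbf{h}) = \sigma(b_i + \sum_j W_{ij} h_j)$. The CD-k algorithm approximates the log-likelihood gradient by contrasting statistics under the data with statistics after $k$ steps of Gibbs sampling started from the data:

$$\Delta W \propto \langle \mathbf{v} \mathbf{h}^\top \rangle_{\text{data}} - \langle \mathbf{v} \mathbf{h}^\top \rangle_{k}.$$

Each column of $W$ connects one hidden unit to all 784 pixels, so it can be reshaped into a 28×28 image; these are the "filters" we visualize at the end.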
In [1]:
import numpy as np
import tensorflow as tf
%matplotlib inline
import matplotlib.pyplot as plt
from IPython import display
# Uncomment the lines below if you did not install xRBM via pip and want to use the local code instead
#import sys
#sys.path.append('../')
We import the xrbm.models module, which contains the RBM model class, and the xrbm.train module, which contains the CD-k approximation algorithm that we use for training our RBM. We also import xrbm.losses for the reconstruction cost and a helper from xrbm.utils.vizutils for visualizing the learned filters.
In [2]:
import xrbm.models
import xrbm.train
import xrbm.losses
from xrbm.utils.vizutils import *
In [3]:
from tensorflow.examples.tutorials.mnist import input_data
data_sets = input_data.read_data_sets('MNIST_data', False)
training_data = data_sets.train.images
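Note: the tensorflow.examples.tutorials.mnist module only ships with TensorFlow 1.x. If it is not available in your install, the following is a rough equivalent using tf.keras.datasets (a sketch under that assumption; the variable names are our own). It flattens each 28×28 image to a 784-dimensional vector and scales pixels to [0, 1], matching the format produced by input_data.read_data_sets:
# Alternative MNIST loading via tf.keras (only needed if the
# tensorflow.examples.tutorials.mnist loader is unavailable)
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
training_data = x_train.reshape(-1, 784).astype(np.float32) / 255.0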
In [4]:
num_vis = training_data[0].shape[0] #=784
num_hid = 200
learning_rate = 0.1
batch_size = 100
training_epochs = 15
In [5]:
# Let's reset the tensorflow graph in case we want to rerun the code
tf.reset_default_graph()
rbm = xrbm.models.RBM(num_vis=num_vis, num_hid=num_hid, name='rbm_mnist')
We shuffle the training indices and compute how many mini-batches that gives us:
In [6]:
batch_idxs = np.random.permutation(range(len(training_data)))
n_batches = len(batch_idxs) // batch_size
We create a placeholder for feeding the mini-batch data during training.

We use the CD-k algorithm to train the RBM. For this, we create an instance of the CDApproximator from the xrbm.train module and pass the learning rate to it. We then define our training op using the CDApproximator's train method, passing it the RBM model and the data placeholder.

In order to monitor the training process, we also compute a cross-entropy reconstruction cost (xentropy_rec_cost below) at the end of each epoch, comparing the data with its reconstruction after a single Gibbs pass.
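For reference, assuming xrbm.losses.cross_entropy implements the standard elementwise binary cross-entropy (check the library source for the exact averaging), the reconstruction cost between a data vector $\mathbf{v}$ and its reconstruction $\hat{\mathbf{v}}$ is

$$\mathcal{L}(\mathbf{v}, \hat{\mathbf{v}}) = -\sum_i \left[ v_i \log \hat{v}_i + (1 - v_i) \log\left(1 - \hat{v}_i\right) \right].$$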
In [7]:
batch_data = tf.placeholder(tf.float32, shape=(None, num_vis))
cdapproximator = xrbm.train.CDApproximator(learning_rate=learning_rate)
train_op = cdapproximator.train(rbm, vis_data=batch_data)
reconstructed_data, _, _, _ = rbm.gibbs_sample_vhv(batch_data)
xentropy_rec_cost = xrbm.losses.cross_entropy(batch_data, reconstructed_data)
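To make the training op a little less of a black box, here is what a single CD-1 update looks like when written out in plain NumPy. This is an illustrative sketch of the algorithm only, not xRBM's actual implementation; W, vis_bias, and hid_bias are hypothetical arrays of shapes (num_vis, num_hid), (num_vis,), and (num_hid,):
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cd1_update(v0, W, vis_bias, hid_bias, lr=0.1):
    # Positive phase: hidden probabilities driven by the data
    h0_prob = sigmoid(v0 @ W + hid_bias)
    h0_sample = (np.random.rand(*h0_prob.shape) < h0_prob).astype(np.float32)

    # Negative phase: one Gibbs step down to the visibles and back up
    v1_prob = sigmoid(h0_sample @ W.T + vis_bias)
    h1_prob = sigmoid(v1_prob @ W + hid_bias)

    # Contrastive divergence: data statistics minus reconstruction statistics
    batch = v0.shape[0]
    dW  = (v0.T @ h0_prob - v1_prob.T @ h1_prob) / batch
    dvb = np.mean(v0 - v1_prob, axis=0)
    dhb = np.mean(h0_prob - h1_prob, axis=0)

    # Gradient ascent on the CD-1 approximation of the log-likelihood
    return W + lr * dW, vis_bias + lr * dvb, hid_bias + lr * dhb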
Finally, we are ready to run everything and see the results:
In [9]:
# Create the figure up front so that we keep drawing the filters on the same one during training
fig = plt.figure(figsize=(12, 8))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epochs):
        for batch_i in range(n_batches):
            # Get the indices of the current mini-batch
            idxs_i = batch_idxs[batch_i * batch_size:(batch_i + 1) * batch_size]

            # Run the training step on this mini-batch
            sess.run(train_op, feed_dict={batch_data: training_data[idxs_i]})

        # Evaluate the reconstruction cost on the full training set once per epoch
        reconstruction_cost = sess.run(xentropy_rec_cost, feed_dict={batch_data: training_data})

        # Visualize the learned filters: each row of W.T holds one hidden unit's weights
        W = rbm.W.eval().transpose()
        filters_grid = create_2d_filters_grid(W, filter_shape=(28, 28), grid_size=(10, 20), grid_gap=(1, 1))

        title = ('Epoch %i / %i | Reconstruction Cost = %f' %
                 (epoch + 1, training_epochs, reconstruction_cost))

        plt.title(title)
        plt.imshow(filters_grid, cmap='gray')

        display.clear_output(wait=True)
        display.display(fig)
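Once the filters look reasonable, it is also worth checking how the model reconstructs digits it was not trained on. A minimal sketch, to be appended inside the with-block above after the training loop; it assumes data_sets.test.images is available (which it is when using input_data.read_data_sets) and reuses the reconstructed_data op defined earlier:
    # Reconstruct a few held-out test digits with a single Gibbs pass
    test_samples = data_sets.test.images[:10]
    recon = sess.run(reconstructed_data, feed_dict={batch_data: test_samples})

    # Top row: original digits, bottom row: their reconstructions
    fig2, axes = plt.subplots(2, 10, figsize=(12, 3))
    for i in range(10):
        axes[0, i].imshow(test_samples[i].reshape(28, 28), cmap='gray')
        axes[1, i].imshow(recon[i].reshape(28, 28), cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].axis('off')
    display.display(fig2)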