Tutorial 1: Training an RBM on the MNIST Dataset

In this tutorial, we train a simple RBM on the MNIST dataset and visualize its learned filters.

The Imports

First, we import the TensorFlow and NumPy packages, as well as the packages we need to visualize the learned filters:


In [1]:
import numpy as np
import tensorflow as tf

%matplotlib inline
import matplotlib.pyplot as plt
from IPython import display

# Uncomment the lines below if you didn't install xRBM via pip and want to use the local code instead
#import sys
#sys.path.append('../')

We import the xrbm.models module, which contains the RBM model class, and the xrbm.train module, which contains the CD-k approximation algorithm that we use for training our RBM. We also import the xrbm.losses module, which provides the reconstruction cost we use to monitor training, and the filter-visualization helpers from xrbm.utils.vizutils.


In [2]:
import xrbm.models
import xrbm.train
import xrbm.losses
from xrbm.utils.vizutils import *

Get the Training Data

We use the MNIST dataset that is provided by the TensorFlow package:


In [3]:
from tensorflow.examples.tutorials.mnist import input_data

# Download (if needed) and load MNIST; we only use the images, not the labels
data_sets = input_data.read_data_sets('MNIST_data', one_hot=False)
training_data = data_sets.train.images


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
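
Each row of training_data is a flattened 28×28 image with pixel values scaled to [0, 1]. As a quick optional sanity check, we can reshape one row back to 28×28 and plot it:

# Optional sanity check: view one flattened training image as a 28x28 picture
plt.imshow(training_data[0].reshape(28, 28), cmap='gray')
plt.title('A sample training digit')
plt.show()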

Set Up the Parameters

The number of visible units of the RBM equals the number of dimensions of the training data. As each image in MNIST is 28×28 pixels, the total number of pixels, and therefore the number of visible units, is 28 × 28 = 784.

We use 200 hidden units, choose a learning rate of 0.1, and train the model for 15 epochs using mini-batches of 100 images.


In [4]:
num_vis         = training_data[0].shape[0] #=784
num_hid         = 200
learning_rate   = 0.1
batch_size      = 100
training_epochs = 15

Create an RBM Model with the Parameters

We create an RBM model and set its number of visible and hidden units. We can also give it a name.


In [5]:
# Reset the default TensorFlow graph in case we want to rerun the code
tf.reset_default_graph()

rbm = xrbm.models.RBM(num_vis=num_vis, num_hid=num_hid, name='rbm_mnist')
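
The model's parameters are created as TensorFlow variables. As a quick optional check, you can print the shape of the weight matrix; judging from the visualization code at the end of this tutorial, it is stored as (num_vis, num_hid):

# Optional: the weight matrix connects the 784 visible to the 200 hidden units
print(rbm.W.get_shape())  # expected: (784, 200)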

We shuffle the training indices and compute the number of mini-batches:


In [6]:
batch_idxs = np.random.permutation(len(training_data))
n_batches  = len(batch_idxs) // batch_size

We create a placeholder for the mini-batch data during training.

We use the CD-k algorithm for training the RBM. For this, we create an instance of the CDApproximator from the xrbm.train module and pass the learning rate to it.

We then define our training op using the CDApproximator's train method, passing the RBM model and the placeholder for the data.

In order to monitor the training process, we also define a reconstruction-cost op: we run one Gibbs sampling pass (visible → hidden → visible) on the input data and measure the cross-entropy between the data and its reconstruction (xentropy_rec_cost).


In [7]:
# Placeholder for a mini-batch of training data
batch_data     = tf.placeholder(tf.float32, shape=(None, num_vis))

# The CD-k approximator and the training op it builds for our RBM
cdapproximator = xrbm.train.CDApproximator(learning_rate=learning_rate)
train_op       = cdapproximator.train(rbm, vis_data=batch_data)

# One Gibbs pass (visible -> hidden -> visible) gives the reconstruction,
# whose cross-entropy against the input serves as the monitoring cost
reconstructed_data, _, _, _ = rbm.gibbs_sample_vhv(batch_data)
xentropy_rec_cost  = xrbm.losses.cross_entropy(batch_data, reconstructed_data)
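
For intuition, here is a minimal NumPy sketch of the CD-1 update that a contrastive divergence approximator performs for a binary RBM. This is illustrative only: the function and variable names are ours, and the actual xRBM implementation may differ in details such as bias handling or momentum.

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# W: (num_vis, num_hid) weights, a: visible biases, b: hidden biases,
# v0: (batch_size, num_vis) mini-batch of training data
def cd1_update(W, a, b, v0, lr=0.1):
    # Positive phase: hidden probabilities given the data
    h0 = sigmoid(v0 @ W + b)
    # Sample binary hidden states and reconstruct the visible layer
    h0_sample = (np.random.rand(*h0.shape) < h0).astype(v0.dtype)
    v1 = sigmoid(h0_sample @ W.T + a)
    # Negative phase: hidden probabilities given the reconstruction
    h1 = sigmoid(v1 @ W + b)
    # Apply the CD-1 gradient estimates, averaged over the mini-batch
    n = v0.shape[0]
    W += lr * (v0.T @ h0 - v1.T @ h1) / n
    a += lr * (v0 - v1).mean(axis=0)
    b += lr * (h0 - h1).mean(axis=0)
    return W, a, b

Running train_op on a mini-batch performs the equivalent update on the TensorFlow graph.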

Finally, we are ready to run everything and see the results:


In [9]:
# Create figure first so that we use the same one to draw the filters on during the training
fig = plt.figure(figsize=(12,8))

with tf.Session() as sess:    
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epochs):
        for batch_i in range(n_batches):
            # Select the training examples for the current mini-batch
            idxs_i = batch_idxs[batch_i * batch_size:(batch_i + 1) * batch_size]
            
            # Run the training step
            sess.run(train_op, feed_dict={batch_data: training_data[idxs_i]})
    
        # Monitor progress: reconstruction cost over the full training set
        reconstruction_cost = sess.run(xentropy_rec_cost, feed_dict={batch_data: training_data})

        # Each row of the transposed weight matrix is one 784-dimensional filter
        W = rbm.W.eval().transpose()
        filters_grid = create_2d_filters_grid(W, filter_shape=(28,28), grid_size=(10, 20), grid_gap=(1,1))
        
        title = ('Epoch %i / %i | Reconstruction Cost = %f'%
                (epoch+1, training_epochs, reconstruction_cost))
        
        plt.title(title)
        plt.imshow(filters_grid, cmap='gray')
        display.clear_output(wait=True)
        display.display(fig)
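
Note that the TensorFlow variables only live while the session is open, so any further use of the trained model has to happen inside the same with tf.Session() block. For example, to compare a few digits with their reconstructions, you could append something like the following sketch, which reuses the reconstructed_data op defined earlier, to the end of the block:

    # Reconstruct a few training digits with the trained model
    recons = sess.run(reconstructed_data,
                      feed_dict={batch_data: training_data[:5]})

    # Top row: original digits; bottom row: their reconstructions
    fig2, axes = plt.subplots(2, 5, figsize=(10, 4))
    for i in range(5):
        axes[0, i].imshow(training_data[i].reshape(28, 28), cmap='gray')
        axes[1, i].imshow(recons[i].reshape(28, 28), cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].axis('off')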