In [1]:
%matplotlib inline

import tensorflow as tf
import numpy as np

MNIST Example with Partial Graph Evaluation

This example illustrates how partialflow can be used to train a neural network with heavy memory requirements on a GPU with limited memory. To keep things simple, we will train a convolutional network on MNIST and use a very large batch size to make training memory-intensive.

First we prepare the MNIST dataset and build a tensorflow input queue with a batch size of 7500:


In [2]:
# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

train_images = np.reshape(mnist.train.images, [-1, 28, 28, 1])
train_labels = mnist.train.labels

test_images = np.reshape(mnist.test.images, [-1, 28, 28, 1])
test_labels = mnist.test.labels

# training input queue with large batch size
batch_size = 7500
image, label = tf.train.slice_input_producer([train_images, train_labels])
image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

Construct Residual Network

partialflow allows us to split a tensorflow graph into several sections which can then be trained separately to lower the memory consumption. This means that the training graph of each section on its own has to fit into GPU memory, whereas the full network's training graph may not.

The graph sections are managed by a GraphSectionManager that orchestrates the data flow between the graph sections during training:


In [3]:
from partialflow import GraphSectionManager

sm = GraphSectionManager()

We can now use the GraphSectionManager to create new sections in which we define our network. partialflow automatically analyzes the tensorflow graph and keeps track of tensors flowing across section borders and variables defined in sections.

In the following, we define our CNN in four sections. This is mainly for illustrative purposes, since two sections might already suffice depending on your GPU memory. We also add some tf.Print statements so that tensorflow logs the forward pass of each section.


In [4]:
from BasicNets import BatchnormNet

# flag for batch normalization layers
is_training = tf.placeholder(name='is_training', shape=[], dtype=tf.bool)
net = BatchnormNet(is_training, image_batch)

# first network section with initial convolution and three residual blocks
with sm.new_section() as sec0:
    with tf.variable_scope('initial_conv'):
        stream = net.add_conv(net._inputs, n_filters=16)
        stream = tf.Print(stream, [stream], 'Forward pass over section 0')
        stream = net.add_bn(stream)
        stream = tf.nn.relu(stream)
    
    with tf.variable_scope('scale0'):
        for i in range(3):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)

                
# second network section with a strided convolution to decrease the input resolution
with sm.new_section() as sec1:
    with tf.variable_scope('scale1'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 1')
        stream = net.res_block(stream, filters_factor=2, first_stride=2)
        for i in range(2):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)

# third network section
with sm.new_section() as sec2:
    with tf.variable_scope('scale2'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 2')
        stream = net.res_block(stream, filters_factor=2, first_stride=2)
        for i in range(4):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)
        
# fourth network section with final pooling and cross-entropy loss
with sm.new_section() as sec3:
    with tf.variable_scope('final_pool'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 3')
        # global average pooling over image dimensions
        stream = tf.reduce_mean(stream, axis=2)
        stream = tf.reduce_mean(stream, axis=1)
        
        # final fully-connected layer for classification
        stream = net.add_fc(stream, out_dims=10)
    
    with tf.variable_scope('loss'):
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=stream, labels=label_batch)
        loss = tf.reduce_mean(loss)

Note that the loss is defined inside a graph section. This is necessary to ensure that the image and label batches are cached and reused during forward and backward passes over the network. If the loss were defined outside a section, the input queues might be evaluated multiple times, which would lead to incorrect gradients being propagated. A sketch of this anti-pattern is shown below.
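For illustration only, the problematic variant would look like this (same identifiers as in the cell above; do not use this in practice):

# anti-pattern (sketch): loss defined outside of any graph section
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=stream, labels=label_batch))
# label_batch is read directly from the input queue, so every evaluation of this
# loss during the section-wise backward passes could dequeue a fresh batch that
# no longer matches the activations cached for the earlier sections.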

Add training operations and prepare training

In order to construct the training graph for our network, we ask the GraphSectionManager to create training operations for each section. This can be done automatically as shown here, or by handing it a list of (possibly preprocessed) gradients as returned by opt.compute_gradients.

The verbose parameter lets the manager add tf.Print statements into the gradient computation in order to log backward passes.


In [5]:
opt = tf.train.AdamOptimizer(learning_rate=0.0001)

sm.add_training_ops(opt, loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), verbose=True)
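If the gradients should be preprocessed before the training ops are built, they can first be computed and modified manually, as mentioned above. The clipping below uses only standard tensorflow calls; how exactly add_training_ops accepts the precomputed gradients is an assumption here and may differ in the actual partialflow API:

# compute the gradients manually and clip them (standard tensorflow)
grads_and_vars = opt.compute_gradients(loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
clipped = [(tf.clip_by_value(g, -1.0, 1.0), v) for g, v in grads_and_vars if g is not None]

# hand the preprocessed gradients to the manager (hypothetical keyword argument)
# sm.add_training_ops(opt, loss, grads_and_vars=clipped, verbose=True)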

Finally, the GraphSectionManager needs to analyze the data flows in forward and backward passes across the graph sections:


In [6]:
sm.prepare_training()

At this point we may perform some sanity checks to validate that the right tensors are cached and fed into the training runs of the different sections. For example, we expect the backward pass of section sec2 to depend on some output of sec1 as well as on gradients computed in sec3:


In [7]:
sec2.get_tensors_to_feed()


Out[7]:
{<tf.Tensor 'gradients/graph_section_3/final_pool/Mean_grad/truediv:0' shape=(7500, 7, 7, 64) dtype=float32>,
 <tf.Tensor 'graph_section_1/scale1/block_1/add:0' shape=(7500, 14, 14, 32) dtype=float32>,
 <tf.Tensor 'is_training:0' shape=() dtype=bool>}

Consequently, the corresponding gradient tensors have to be cached during the backward pass of sec3:


In [8]:
sec3.get_tensors_to_cache()


Out[8]:
{<tf.Tensor 'gradients/graph_section_3/final_pool/Mean_grad/truediv:0' shape=(7500, 7, 7, 64) dtype=float32>}

Run Forward and Backward Passes

We can now open a new session and initialize our graph:


In [9]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
_ = tf.train.start_queue_runners(sess=sess)

As usual, a simple forward pass that ignores the sections can be performed using session.run:


In [10]:
sess.run(loss, feed_dict={is_training: True})


Out[10]:
11.650328

This should log a single forward pass for each section into the console running this notebook:

Forward pass over section 0[[[[0 0 0]]]...]
Forward pass over section 1[[[[0.25482464 0.54249996 0.15713426]]]...]
Forward pass over section 2[[[[1.3474643 0.62452459 0.14982516]]]...]
Forward pass over section 3[[[[0.52292633 -0.39113081 0.74775648]]]...]

Training

Since intermediate results need to be cached for the backward passes, training operations have to be run through the GraphSectionManager. The run_full_cycle method runs a forward pass, caches intermediate results as needed, and then performs a backward pass over the training operations.

Forward passes are not performed section-wise, because tensorflow already optimizes memory consumption by dropping intermediate results; hence the full forward-pass graph is assumed to fit into GPU memory. In contrast, backward passes are performed section-wise. run_full_cycle takes care of evaluating the graph elements passed in fetches during the right phases of this procedure.

The following should log a full forward pass followed by interleaved forward and backward passes for each section. Note that the basic_feed parameter is used analogously to feed_dict in session.run:


In [11]:
sm.run_full_cycle(sess, fetches=loss, basic_feed={is_training:True})


Out[11]:
11.725137

... which should log something like

Forward pass over section 0[[[[0 0 0]]]...]
Forward pass over section 1[[[[0.25558376 0.54339874 0.15626775]]]...]
Forward pass over section 2[[[[1.3982055 0.59655607 0.18760961]]]...]
Forward pass over section 3[[[[1.2530568 -1.3083258 -0.73674989]]]...]
Running backward pass on section 3[-0.099787384 -0.11186664 -0.0903545...]
Forward pass over section 2[[[[1.3982055 0.59655607 0.18760961]]]...]
Running backward pass on section 2[[[[-0.55458224 0.38391361 -0.357202]]]...]
Forward pass over section 1[[[[0.25558376 0.54339874 0.15626775]]]...]
Running backward pass on section 1[[[[-0.0038170468 0.001159993 0.0018510161]]]...]
Forward pass over section 0[[[[0 0 0]]]...]
Running backward pass on section 0[-0.098705873 0.026503615 -0.0099251084...]
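To actually train the network, run_full_cycle is simply called in a loop. A minimal sketch (the number of steps and the logging interval are arbitrary):

# minimal training loop using the manager; each cycle performs a full forward pass
# plus section-wise backward passes as logged above
for step in range(100):
    train_loss = sm.run_full_cycle(sess, fetches=loss, basic_feed={is_training: True})
    if step % 10 == 0:
        print('step %d: training loss %.4f' % (step, train_loss))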