In [1]:
%matplotlib inline
import tensorflow as tf
import numpy as np
This example illustrates the use of partialflow for training a memory-hungry neural network on a GPU with limited memory. To keep things simple, we will train a convolutional network on MNIST and use a very large batch size to make the training process memory-intensive.
First we prepare the MNIST dataset and build a tensorflow input queue with a batch size of 7500:
In [2]:
# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_images = np.reshape(mnist.train.images, [-1, 28, 28, 1])
train_labels = mnist.train.labels
test_images = np.reshape(mnist.test.images, [-1, 28, 28, 1])
test_labels = mnist.test.labels
# training input queue with large batch size
batch_size = 7500
image, label = tf.train.slice_input_producer([train_images, train_labels])
image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size)
partialflow allows us to split a tensorflow graph into several sections which can then be trained separately to reduce memory consumption. This means that the training graph of each section on its own has to fit into GPU memory, whereas the full network's training graph may not.
The graph sections are managed by a GraphSectionManager
that orchestrates the data flow between the graph sections during training:
In [3]:
from partialflow import GraphSectionManager
sm = GraphSectionManager()
We can now use the GraphSectionManager
to create new sections in which we define our network. partialflow
automatically analyzes the tensorflow
graph and keeps track of tensors flowing across section borders and variables defined in sections.
In the following, we define our CNN in four sections. This is mainly done for illustrative purposes since two sections might already suffice, depending on your GPU memory. We added some tf.Print
statements to make tensorflow
log forward passes for each section.
In [4]:
from BasicNets import BatchnormNet

# flag for batch normalization layers
is_training = tf.placeholder(name='is_training', shape=[], dtype=tf.bool)
net = BatchnormNet(is_training, image_batch)

# first network section with initial convolution and three residual blocks
with sm.new_section() as sec0:
    with tf.variable_scope('initial_conv'):
        stream = net.add_conv(net._inputs, n_filters=16)
        stream = tf.Print(stream, [stream], 'Forward pass over section 0')
        stream = net.add_bn(stream)
        stream = tf.nn.relu(stream)
    with tf.variable_scope('scale0'):
        for i in range(3):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)

# second network section with a strided convolution to decrease the input resolution
with sm.new_section() as sec1:
    with tf.variable_scope('scale1'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 1')
        stream = net.res_block(stream, filters_factor=2, first_stride=2)
        for i in range(2):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)

# third network section
with sm.new_section() as sec2:
    with tf.variable_scope('scale2'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 2')
        stream = net.res_block(stream, filters_factor=2, first_stride=2)
        for i in range(4):
            with tf.variable_scope('block_%d' % i):
                stream = net.res_block(stream)

# fourth network section with final pooling and cross-entropy loss
with sm.new_section() as sec3:
    with tf.variable_scope('final_pool'):
        stream = tf.Print(stream, [stream], 'Forward pass over section 3')
        # global average pooling over image dimensions
        stream = tf.reduce_mean(stream, axis=2)
        stream = tf.reduce_mean(stream, axis=1)
        # final fully connected layer for classification
        stream = net.add_fc(stream, out_dims=10)

    with tf.variable_scope('loss'):
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=stream, labels=label_batch)
        loss = tf.reduce_mean(loss)
Note that the loss is defined inside a graph section. This is necessary to ensure that the image and label batches are cached and reused during the forward and backward passes over the network. If the loss were defined outside a section, the input queues might be evaluated multiple times, which would lead to incorrect gradients being propagated.
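To see why, recall that queue-backed tensors dequeue a fresh batch on every evaluation. The following sketch (plain TensorFlow, independent of partialflow, and assuming a session with running queue runners as set up further below) illustrates this:
b1 = sess.run(label_batch)  # dequeues one batch of 7500 labels
b2 = sess.run(label_batch)  # dequeues the *next* batch -- generally different from b1
# If the loss lived outside a section, the backward passes could see different
# batches than the forward pass; caching the batches inside the section avoids this.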
In order to construct the training graph for our network, we ask the GraphSectionManager
to create training operations for each section. This can be done automatically as shown here, or by handing it a list of (possibly preprocessed) gradients as returned by opt.compute_gradients
.
The verbose
parameter lets the manager add tf.Print
statements into the gradient computation in order to log backward passes.
In [5]:
opt = tf.train.AdamOptimizer(learning_rate=0.0001)
sm.add_training_ops(opt, loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), verbose=True)
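Alternatively, as mentioned above, gradients can be computed and preprocessed manually with opt.compute_gradients before handing them to the manager. The sketch below clips gradients as an example of such preprocessing; the keyword used to pass them to add_training_ops is an assumption here and should be checked against the partialflow documentation.
# sketch of the manual path: preprocess gradients before registering training ops
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
grads_and_vars = opt.compute_gradients(loss, var_list=var_list)
clipped = [(tf.clip_by_value(g, -1.0, 1.0), v)
           for g, v in grads_and_vars if g is not None]
# NOTE: the keyword below is an assumption, not taken from the partialflow docs
# sm.add_training_ops(opt, loss, grads_and_vars=clipped, verbose=True)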
Finally, the GraphSectionManager
needs to analyze the data flows in forward and backward passes across the graph sections:
In [6]:
sm.prepare_training()
At this point we may perform some sanity checks to validate that the right tensors are cached and fed into the training runs of the different sections. For example, we expect the backward pass of section sec2 to depend on some output of sec1 as well as on gradients computed in sec3:
In [7]:
sec2.get_tensors_to_feed()
Out[7]:
Consequently, the corresponding gradient tensors have to be cached during the backward pass of sec3
:
In [8]:
sec3.get_tensors_to_cache()
Out[8]:
In [9]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
_ = tf.train.start_queue_runners(sess=sess)
As usual, a simple forward pass ignoring the sections can be performed using session.run
.
In [10]:
sess.run(loss, feed_dict={is_training: True})
Out[10]:
This should log a single forward pass for each section into the console running this notebook:
Forward pass over section 0[[[[0 0 0]]]...]
Forward pass over section 1[[[[0.25482464 0.54249996 0.15713426]]]...]
Forward pass over section 2[[[[1.3474643 0.62452459 0.14982516]]]...]
Forward pass over section 3[[[[0.52292633 -0.39113081 0.74775648]]]...]
Since intermediate results need to be cached for the backward pass, training operations have to be run through the GraphSectionManager. Its run_full_cycle method will run a forward pass, cache intermediate results as needed, and perform a backward pass over the training operations.
Forward passes are not performed section-wise, because tensorflow
optimizes memory consumption by dropping intermediate results anyway. Hence the full forward pass graph is assumed to fit into GPU memory. In contrast, backward passes are performed section-wise. run_full_cycle
takes care of evaluating the graph elements in fetches
during the right phases of this procedure.
The following should log a full forward pass followed by interleaved forward and backward passes for each section. Note that the basic_feed parameter is used analogously to feed_dict in session.run:
In [11]:
sm.run_full_cycle(sess, fetches=loss, basic_feed={is_training:True})
Out[11]:
... which should log something like
Forward pass over section 0[[[[0 0 0]]]...]
Forward pass over section 1[[[[0.25558376 0.54339874 0.15626775]]]...]
Forward pass over section 2[[[[1.3982055 0.59655607 0.18760961]]]...]
Forward pass over section 3[[[[1.2530568 -1.3083258 -0.73674989]]]...]
Running backward pass on section 3[-0.099787384 -0.11186664 -0.0903545...]
Forward pass over section 2[[[[1.3982055 0.59655607 0.18760961]]]...]
Running backward pass on section 2[[[[-0.55458224 0.38391361 -0.357202]]]...]
Forward pass over section 1[[[[0.25558376 0.54339874 0.15626775]]]...]
Running backward pass on section 1[[[[-0.0038170468 0.001159993 0.0018510161]]]...]
Forward pass over section 0[[[[0 0 0]]]...]
Running backward pass on section 0[-0.098705873 0.026503615 -0.0099251084...]
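Putting everything together, training amounts to calling run_full_cycle repeatedly. Below is a minimal loop sketch using only the calls shown above; it assumes run_full_cycle returns the fetched values (as the Out cell above suggests), and the number of cycles and the logging interval are arbitrary choices. The tf.Print statements will keep logging forward and backward passes to the console.
# minimal training-loop sketch using the calls demonstrated above
for step in range(100):
    loss_value = sm.run_full_cycle(sess, fetches=loss, basic_feed={is_training: True})
    if step % 10 == 0:
        print('step %d: loss %f' % (step, loss_value))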