TensorFlow Scan

tf.scan (tensorflow.scan) lets you write a loop inside a computation graph without explicit Python loop structures such as for, and TensorFlow handles backpropagation through the loop automatically. Unrolling a loop explicitly instead creates new graph nodes for every iteration of the loop body, which also means the number of iterations must be fixed when the graph is built.
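To make the semantics concrete, here is a minimal pure-Python sketch of the recurrence tf.scan evaluates. It is not TensorFlow's implementation (which runs the loop inside the graph using a tf.while_loop). The name scan_reference is a hypothetical helper, and it handles only flat Python sequences, not nested tensor structures.

```python
def scan_reference(fn, elems, initializer=None):
    """Pure-Python sketch of tf.scan's recurrence on a flat sequence.

    With an initializer a: out[0] = fn(a, elems[0]) and
    out[i] = fn(out[i-1], elems[i]). Without one, elems[0] seeds the
    accumulator and is itself emitted as out[0].
    """
    elems = list(elems)
    if initializer is None:
        if not elems:
            raise ValueError("elems must be non-empty when no initializer is given")
        acc, rest, outputs = elems[0], elems[1:], [elems[0]]
    else:
        acc, rest, outputs = initializer, elems, []
    for x in rest:
        acc = fn(acc, x)  # the previous output is fed back as the first argument
        outputs.append(acc)
    return outputs


def fn(previous_output, current_input):
    return previous_output + current_input


print(scan_reference(fn, [1.0, 2.0, 2.0, 2.0], initializer=0.0))  # [1.0, 3.0, 5.0, 7.0]
print(scan_reference(fn, [1.0, 2.0, 2.0, 2.0]))                   # [1.0, 3.0, 5.0, 7.0]
```

Both calls agree here only because adding 0.0 changes nothing; with initializer=10.0 the outputs shift to [11.0, 13.0, 15.0, 17.0].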

Cumulative Sum


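import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

def fn(previous_output, current_input):
    # previous_output: value returned at the previous step (or the initializer)
    # current_input: the current element of elems
    return previous_output + current_input

elems = tf.Variable([1.0, 2.0, 2.0, 2.0])
# Convert the Variable to a plain Tensor; in the TensorFlow version this
# was written for, tf.scan failed when given the Variable directly.
elems = tf.identity(elems)

initializer = tf.constant(0.0)
out = tf.scan(fn, elems, initializer=initializer)

with tf.Session() as session:
    init_op = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()
    session.run(init_op)
    value = session.run(out)
    print(value)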


[ 1.  3.  5.  7.]
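Written out as a recurrence, the printed values follow step by step. Here $f$ stands for fn, $a$ for the initializer and $x_t$ for the $t$-th element of elems; these symbols are introduced only for illustration.

$$
\begin{aligned}
y_0 &= f(a, x_0) = 0 + 1 = 1 \\
y_1 &= f(y_0, x_1) = 1 + 2 = 3 \\
y_2 &= f(y_1, x_2) = 3 + 2 = 5 \\
y_3 &= f(y_2, x_3) = 5 + 2 = 7
\end{aligned}
$$

tf.scan returns the whole sequence $(y_0, y_1, y_2, y_3)$; the initializer $a$ itself is not part of the output.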

Loop Equivalence


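import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

def fn(previous_output, current_input):
    return previous_output + current_input

elems = tf.Variable([1.0, 2.0, 2.0, 2.0])
elems = tf.identity(elems)  # plain Tensor, as in the tf.scan cell

initializer = tf.constant(0.0)
cum_sum = initializer  # start from the same value tf.scan starts from

# Unrolled loop: every iteration adds new nodes to the graph, so the number
# of iterations (4 here) must be known when the graph is built.
num_elems = int(elems.get_shape()[0])
for x in tf.split(elems, num_elems, axis=0):  # four tensors of shape [1]
    cum_sum = fn(cum_sum, x)

with tf.Session() as session:
    init_op = tf.global_variables_initializer()
    session.run(init_op)
    value = session.run(cum_sum)
    print(value)
    # Like a reduce: only the final sum is returned, whereas tf.scan
    # also returns every intermediate partial sum.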
    #like a reduce operation (but it scans over elements)


[ 7.]
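The reduce-versus-scan distinction in the last comment can be checked with the Python standard library. In this plain-Python analogy (not TensorFlow code), functools.reduce plays the role of the unrolled loop, keeping only the final sum, while itertools.accumulate keeps every partial sum, as tf.scan does.

```python
import functools
import itertools
import operator

elems = [1.0, 2.0, 2.0, 2.0]

# Unrolled loop: each iteration overwrites the running total; only the end value survives.
cum_sum = 0.0
for x in elems:
    cum_sum = cum_sum + x

# reduce: the same fold, returning the final value only.
reduced = functools.reduce(operator.add, elems, 0.0)

# accumulate: the same fold, but every intermediate value is kept (what tf.scan returns).
scanned = list(itertools.accumulate(elems, operator.add))

print(cum_sum, reduced)  # 7.0 7.0
print(scanned)           # [1.0, 3.0, 5.0, 7.0]
assert scanned[-1] == reduced == cum_sum
```

Passing initial=0.0 to accumulate (Python 3.8+) would prepend 0.0 to its output, whereas tf.scan does not emit its initializer, so the call above omits it.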
