MultiRNNCell with shared weights

A MultiRNNCell object that stacks several LSTM cell with weight sharing.


In [7]:
import tensorflow as tf  # 0.12.1
import numpy as np

import collections
# import math

# These can be simplified with tf.<fn> instead
# from tensorflow.python.ops.math_ops import sigmoid
# from tensorflow.python.ops.math_ops import tanh

# from tensorflow.python.framework import ops
# from tensorflow.python.ops import array_ops
# from tensorflow.python.ops import clip_ops
# from tensorflow.python.ops import nn_ops
# from tensorflow.python.ops import math_ops
# from tensorflow.python.ops import init_ops
# from tensorflow.python.ops import variable_scope as vs

# from tensorflow.python.util import nest  # some Tensorflow magic

# Some jupyter magic
# automatic debugging
# %pdb

Custom LSTM Cell


In [3]:
# Heavily based on original tf.nn.rnn_cell_LSTMCell()
# with modifications to variable_scope to allow for Tensorflow 
# weight sharing.


# LSTMCell helper functions
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable


# LSTMCell helper functions
def _get_sharded_variable(name, shape, dtype, num_shards):
    """Get a list of sharded variables with the given dtype."""
    if num_shards > shape[0]:
        raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                       (shape, num_shards))
    unit_shard_size = int(math.floor(shape[0] / num_shards))
    remaining_rows = shape[0] - unit_shard_size * num_shards

    shards = []
    for i in range(num_shards):
        current_size = unit_shard_size
        if i < remaining_rows:
            current_size += 1
        shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
                                      dtype=dtype))
    return shards


class LSTMCell(tf.nn.rnn_cell.RNNCell):
    """Long short-term memory unit (LSTM) recurrent network cell.

    The default non-peephole implementation is based on:

        http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

    S. Hochreiter and J. Schmidhuber.
    "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

    The peephole implementation is based on:

        https://research.google.com/pubs/archive/43905.pdf

    Hasim Sak, Andrew Senior, and Francoise Beaufays.
    "Long short-term memory recurrent neural network architectures for
     large scale acoustic modeling." INTERSPEECH, 2014.

    The class uses optional peep-hole connections, optional cell clipping, and
    an optional projection layer.
    """
    def __init__(self, num_units, input_size=None,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=1, num_proj_shards=1,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=tanh):
        """Initialize the parameters for an LSTM cell.

        Args:
            num_units: int, The number of units in the LSTM cell
            input_size: Deprecated and unused.
            use_peepholes: bool, set True to enable diagonal/peephole connections.
            cell_clip: (optional) A float value, if provided the cell state is clipped
                by this value prior to the cell output activation.
            initializer: (optional) The initializer to use for the weight and
                projection matrices.
            num_proj: (optional) int, The output dimensionality for the projection
                matrices.  If None, no projection is performed.
            proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
            provided, then the projected values are clipped elementwise to within
            `[-proj_clip, proj_clip]`.
            num_unit_shards: How to split the weight matrix.  If >1, the weight
                matrix is stored across num_unit_shards.
            num_proj_shards: How to split the projection matrix.  If >1, the
                projection matrix is stored across num_proj_shards.
            forget_bias: Biases of the forget gate are initialized by default to 1
                in order to reduce the scale of forgetting at the beginning of
                the training.
            state_is_tuple: If True, accepted and returned states are 2-tuples of
                the `c_state` and `m_state`.  If False, they are concatenated
                along the column axis.  This latter behavior will soon be deprecated.
            activation: Activation function of the inner states.
        """
        if not state_is_tuple:
            tf.logging.warn("%s: Using a concatenated state is slower and will soon be "
                            "deprecated.  Use state_is_tuple=True.", self)
        if input_size is not None:
            tf.logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._use_peepholes = use_peepholes
        self._cell_clip = cell_clip
        self._initializer = initializer
        self._num_proj = num_proj
        self._proj_clip = proj_clip
        self._num_unit_shards = num_unit_shards
        self._num_proj_shards = num_proj_shards
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation

        if num_proj:
            self._state_size = (
                tf.nn.rnn_cell.LSTMStateTuple(num_units, num_proj)
                if state_is_tuple else num_units + num_proj)
            self._output_size = num_proj
        else:
            self._state_size = (
                tf.nn.rnn_cell.LSTMStateTuple(num_units, num_units)
                if state_is_tuple else 2 * num_units)
            self._output_size = num_units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def __call__(self, inputs, state, scope=None):
        """Run one step of LSTM.

        Args:
            inputs: input Tensor, 2D, batch x num_units.
            state: if `state_is_tuple` is False, this must be a state Tensor,
                `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
                tuple of state Tensors, both `2-D`, with column sizes `c_state` and
                `m_state`.
        scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
            A tuple containing:

            - A `2-D, [batch x output_dim]`, Tensor representing the output of the
                LSTM after reading `inputs` when previous state was `state`.
                Here output_dim is:
                    num_proj if num_proj was set,
                    num_units otherwise.
            - Tensor(s) representing the new state of LSTM after reading `inputs` when
                the previous state was `state`.  Same type and shape(s) as `state`.

        Raises:
            ValueError: If input size cannot be inferred from inputs via
                static shape inference.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        if self._state_is_tuple:
            (c_prev, m_prev) = state
        else:
            c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
            m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
        with vs.variable_scope(scope or type(self).__name__,
                               initializer=self._initializer):  # "LSTMCell"
            concat_w = _get_concat_variable(
                "W", [input_size.value + num_proj, 4 * self._num_units],
                dtype, self._num_unit_shards)

            b = vs.get_variable(
                "B", shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer, dtype=dtype)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            cell_inputs = array_ops.concat(1, [inputs, m_prev])
            lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable(
                    "W_F_diag", shape=[self._num_units], dtype=dtype)
                w_i_diag = vs.get_variable(
                    "W_I_diag", shape=[self._num_units], dtype=dtype)
                w_o_diag = vs.get_variable(
                    "W_O_diag", shape=[self._num_units], dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
                     sigmoid(i + w_i_diag * c_prev) * self._activation(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * 
                     self._activation(j))

            if self._cell_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
                # pylint: enable=invalid-unary-operand-type

            if self._use_peepholes:
                m = sigmoid(o + w_o_diag * c) * self._activation(c)
            else:
                m = sigmoid(o) * self._activation(c)

            if self._num_proj is not None:
                concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj],
                        dtype, self._num_proj_shards)

                m = math_ops.matmul(m, concat_w_proj)
                if self._proj_clip is not None:
                    # pylint: disable=invalid-unary-operand-type
                    m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
                    # pylint: enable=invalid-unary-operand-type

        print('DEBUG')
        new_state = (tf.nn.rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple
                 else array_ops.concat(1, [c, m]))
        return m, new_state

'''
Custom LSTMCell

Based on http://colah.github.io/posts/2015-08-Understanding-LSTMs/

Modified tensorflow.nn.rnn_cell.LSTMCell to remove variable scopes.

This cell allows us to customise variable scope which in turn
lets us share variables between cells. It is otherwise identical.
''' 
class BasicLSTMCell_shared_weights(tf.nn.rnn_cell.RNNCell):
    def __init__(self, num_units, forget_bias=1.0, input_size=None,
                 activation=tanh):
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._activation = activation

    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
    
        c, h = state
        
        # concat
        # concat = _linear([inputs, h], 4 * self._num_units, True)
        
        args = [inputs, h]
        
        total_arg_size = 0
        shapes = [a.get_shape().as_list() for a in args]
        for shape in shapes:
            total_arg_size += shape[1]
        
        output_size = 4 * self._num_units
        
        # shared
        with tf.variable_scope('LSTMCell_shared_weights'):
            matrix = tf.get_variable('W', [total_arg_size, output_size])
            bias_term = tf.get_variable('b', [output_size],
                initializer=tf.constant_initializer(0.0))
            
        res = tf.matmul(tf.concat(1, args), matrix) + bias_term
        
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(1, 4, res)
        
        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * 
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)
        
        new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
        return new_h, new_state

'''
Custom MultiRNNCell wrapper

Exactly the Same as tf.nn.rnn_cell.MultiRNNCell but without the 
name scoping. The original wrapper adds /cell# to the name scope,
which would have created new variables instead of reusing exising
ones.

This class enables you to stack LSTMCells ontop of each other with
the same input. The multicell simply passes the input and output to
the cells in the stack.
'''
class MultiRNNCell_shared_weights(tf.nn.rnn_cell.RNNCell):
    """RNN cell composed sequentially of multiple simple cells."""
    def __init__(self, cells):
        self._cells = cells

    @property
    def state_size(self):
        return tuple(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        return self._cells[-1].output_size

    def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state."""
        cur_state_pos = 0
        cur_inp = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            cur_state = state[i]
            cur_inp, new_state = cell(cur_inp, cur_state)
            new_states.append(new_state)
        new_states = tuple(new_states)
        return cur_inp, new_states

# From tf.nn.rnn_cell._linear()
# Original authors: tensorflow
# This could be expanded to n-dimensions later
# def _linear(args, output_size, bias, bias_start=0.0, scope=None):
#     """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

#     Args:
#         args: a 2D Tensor or a list of 2D, batch x n, Tensors.
#         output_size: int, second dimension of W[i].
#         bias: boolean, whether to add a bias term or not.
#         bias_start: starting value to initialize the bias; 0 by default.
#         scope: VariableScope for the created subgraph; defaults to "Linear".

#     Returns:
#         A 2D Tensor with shape [batch x output_size] equal to
#         sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

#     Raises:
#         ValueError: if some of the arguments has unspecified or wrong shape.
#     """
#     if args is None or (nest.is_sequence(args) and not args):
#         raise ValueError("`args` must be specified")
#     if not nest.is_sequence(args):
#         args = [args]

#     # Calculate the total size of arguments on dimension 1.
#     total_arg_size = 0
#     shapes = [a.get_shape().as_list() for a in args]
#     for shape in shapes:
#     if len(shape) != 2:
#         raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
#     if not shape[1]:
#         raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
#     else:
#         total_arg_size += shape[1]

#         dtype = [a.dtype for a in args][0]

#         # Now the computation.
#         with tf.variable_scope(scope or "Linear"):
#             matrix = tf.get_variable(
#                 "Matrix", [total_arg_size, output_size], dtype=dtype)
#         if len(args) == 1:
#             res = tf.matmul(args[0], matrix)
#         else:
#             res = tf.matmul(tf.concat(1, args), matrix)
#         if not bias:
#             return res
#         bias_term = tf.get_variable(
#             "Bias", [output_size],
#             dtype=dtype,
#             initializer=tf.constant_initializer(
#                 bias_start, dtype=dtype))
#         return res + bias_term

LSTM equations

$i = \sigma(x_tU^i + s_{t-1}W^i + b^i) \\ f = \sigma(x_tU^f + s_{t-1}W^f + b^f) \\ o = \sigma(x_tU^o + s_{t-1}W^o + b^o) \\ g = \tanh(x_tU^g + s_{t-1}W^g + b^g) \\ c_t = c_{t-1} \odot f + g \odot i \\ s_t = \tanh(c_t) \odot o $

Matrix and vector dimensions

  • $U \text{ has dimensions } n \times m$ (hidden to input)
  • $W \text{ has dimensions } n \times n$ (hidden to hidden)
  • One set of matrixes for each of the three gates {$i, f, o$}
  • One extra set for the cell state
  • We also use an optional bias of dimension $n$ for each gate and the state

$4(n \times m) + 4(n^2) + 4n\\ 4n \times 4m + 4n^2 + 4n\\ 4(nm + n^2 + n)$

$n = 32\\ m = 8$

Total parameters: $4(32\times8 + 32^2 + 32) = 5248$

$\displaystyle \\ x_t \in \mathbb{R}^8 \\ o_t \in \mathbb{R}^8 \\ s_t \in \mathbb{R}^{32} \\ U^i \in \mathbb{R}^{32 \times 8} \\ U^f \in \mathbb{R}^{32 \times 8} \\ U^o \in \mathbb{R}^{32 \times 8} \\ U^g \in \mathbb{R}^{32 \times 8} \\ W^i \in \mathbb{R}^{32 \times 32} \\ W^f \in \mathbb{R}^{32 \times 32} \\ W^o \in \mathbb{R}^{32 \times 32} \\ W^g \in \mathbb{R}^{32 \times 32} \\ b^i \in \mathbb{R}^{32} \\ b^f \in \mathbb{R}^{32} \\ b^o \in \mathbb{R}^{32} \\ b^g \in \mathbb{R}^{32}$

Build graph


In [40]:
import numpy as np
import tensorflow as tf


def reset_graph():
    '''
    Convenience function to reset graph at each run.
    If the graph is left open Tensorflow automatically appends new
    graph functions making things slow.'''
    if 'sess' in globals() and sess:
        sess.close()  # Close any open session
    tf.reset_default_graph()


def build_multicell_lstm_graph(
    state_size=16,  #  c + h, use num_units * 2 instead
    num_units=32,
    batch_size=8,
    num_steps=8,
    num_classes=8,
    num_cells=1,
    learning_rate=1e-4):
    '''
    This function builds a dynamic rnn graph with multiple LSTM cells.
    The LSTM cells share weights between them though the modified
    MultiRNNCell wrapper.
    
    Vanilla RNN equations:
    st = tanh(Uxt + Wst-1)
    ot softmax(Vst)
    
    H = 8 (num_units)
    C = 2 (grid_size)
    
    xt = 2 (grid_size)
    ot = 2
    st = 8 (num_units)
    U = 8 x 2
    V = 2 x 8
    W = 8 x 8
    
    '''
    # reset graph for testing purposes
    reset_graph()
    
    # Get dataset
    Dataset = build_1d_dataset(n=8, num_samples=1000)
    
    # Create placeholder for minibatch
    # (num_samples, timesteps, depth)
    x = tf.placeholder(tf.float32, [None, num_steps, 1])
    y = tf.placeholder(tf.int32, [None])
    
    # shape (batch_size, num_steps, state_size)
#     rnn_inputs = tf.nn.embedding_lookup(tf.random_normal([num_classes, state_size]), x)

    # Create variables
#     with tf.variable_scope('RNN') as scope:  # Save scope in variable
#         # For LSTM use [2n x 4n] weights for all 4 gates to allow us
#         # to multiply them all at once, saving computation.
#         W = tf.get_variable('LSTMCell/W_0',
#                             [2 * num_units, 4 * num_units])
#         # Similarly, we have [4n] biases
#         b = tf.get_variable('LSTMCell/B', [4 * num_units],
#                             initializer=tf.constant_initializer(0.0))

    # Mark scope to reuse from now on. ie: any variable created using
    # this scope will reuse variables if they are found with get_variable().
    # To create new variables set scope reuse=False or call this function
    # again.
#     scope.reuse_variables()
    
    cell = tf.nn.rnn_cell.LSTMCell(num_units)
    
    # Use this to stack several cells in a tower
#     cell = MultiRNNCell_shared_weights([cell] * num_cells)  # custom wrapper
    
    # Cells are unrolled in tf.nn.dynamic_rnn and weights are shared between them
#     rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, dtype=tf.float32, scope=scope)
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

    with tf.variable_scope('softmax'):  # Save scope in variable
        W = tf.get_variable('W', [num_units, num_units])
        b = tf.get_variable('b', [num_units], initializer=tf.constant_initializer(0.0))
#     logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]
    
    # flatten to apply weights to all timesteps
    rnn_output = tf.reshape(rnn_outputs, [-1, num_units])
    prediction = tf.nn.softmax(tf.matmul(rnn_output, W) + b)
    prediction = tf.reshape(prediction, [-1, num_steps, 1])
    
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction, y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    
    return dict(
        x = x,
        y = y,
        final_state = final_state,
        total_loss = total_loss,
        train_step = train_step
    )
    
# build_multicell_lstm_graph()

Training and evaluations


In [ ]:
def train_network(g, num_epochs,
                 num_steps=200,
                 batch_size=8,
                 verbose=True,
                 save=False):
    tf.set_random_seed(42)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        training_losses = []
        for i, epoch in enumerate(gen_epochs())

Testing shared variables


In [29]:
'''
Build graph and test if we only have two variables
'''
build_multicell_lstm_graph()

def count_parameters_and_variables():
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        print('{}, shape={}, params={}'.format(variable.name, shape, variable_parameters))
        total_parameters += variable_parameters
    print('total_parameters:', total_parameters)
    
    g = tf.get_default_graph()
    vars = g.get_collection('variables')
    print(len(vars), 'variables')
    
    return len(vars), total_parameters

assert count_parameters_and_variables()[0] == 2


RNN/LSTMCell/W_0:0, shape=(33, 128), params=4224
RNN/LSTMCell/B:0, shape=(128,), params=128
total_parameters: 4352
2 variables

Generating data


In [18]:
# import numpy as np 
# from collections import namedtuple

def generate_samples(num_samples=1000):
    pass

def build_1d_dataset(
    n_length=8,  # length of the grid
    num_samples=100000,
    k_value=2,  # number of colours k=2: binary
    train_split=0.8,  # fraction of full dataset for training
    valid_split=0.5,  # fraction of test set for validation
#     random_distribution=True,
    verbose=False):
    '''build 1d board training dataset of num_samples with n length 
    for machine learning.
    
    The boards are randomly distributed binary "stones" (1 or 0)
    
    inputs:
        n: length of the board (1d)
        num_samples: how many samples
        k_value: number of colors of the grid default k=2 is binary
        train_split: 
        valid_split: 
        verbose: show statistics
        
    returns:
        A named tuple 'Dataset' with train, valid and test datasets
            of (data, labels) length
    '''
    
    # Generate random num_samples with n length in shape=(num_samples, n)
    # It generates the random distribution from 0 to k_value
    data = np.random.randint(0, k_value, size=[num_samples, n_length])
    
    # Generate labels from data
    # The arbitrary problem for the machine is to find the
    # connection length from left to right.
    labels = np.zeros(num_samples, dtype=int)
    for i, board in enumerate(data):

#         if np.sum(board, axis=0) == n:  # quickly get fully connected boards
#             labels[i] = 1
#         else:
#             labels[i] = 0

        # Stepwise look for 1's per grid
        connection_length = 0
        for j, grid in enumerate(board):
            if grid == 1:
                connection_length += 1 
            else:
                break  # stop looking to save some computation
        labels[i] = connection_length
    
    # Create dataset named tuple
    Dataset = namedtuple('Dataset', ['train', 'valid', 'test'])
    
    # Split dataset
    num_train = int(num_samples * train_split)
    num_valid = int((num_samples - num_train) * valid_split)
    
    Dataset.train = (data[:num_train], labels[:num_train])
    Dataset.valid = (data[:num_valid], labels[:num_valid])
    Dataset.test = (data[:num_valid], labels[:num_valid])
    
    return Dataset

# %time build_1d_dataset();


Out[18]:
10000

In [ ]: