Text generator based on RNN

Brief

Generate fake abstracts with an RNN model under TensorFlow r1.4.

Import libraries


In [1]:
import tensorflow as tf
import numpy as np
import random
import os

Configurations


In [2]:
vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            "\\^_abcdefghijklmnopqrstuvwxyz{|}\n")
graph_path = r"./graphs"
test_text_path = os.path.normpath(r"../Dataset/arvix_abstracts.txt")
batch_size=50
model_param_path=os.path.normpath(r"./model_checkpoints")

Data encoding

Basic Assumption

  • A full string sequence consists of a $START$ signal, followed by the characters, followed by a $STOP$ signal.

Encoding policy

  • A set $\mathcal{S}$ of characters (the vocabulary) is used to encode the characters; each character becomes a one-hot vector of dimension $|\mathcal{S}| + 2$.
  • The $1^{st}$ entry of the vector corresponds to $UNKNOWN$ characters (i.e. characters that are not in $\mathcal{S}$).
  • The last entry of the vector corresponds to the $STOP$ signal of the sequence.
  • The entries in the middle correspond to the indices of the characters within $\mathcal{S}$.
  • The $START$ signal is represented as a zero vector.
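The policy can be sketched with NumPy alone. This is a minimal illustration, not the notebook's implementation: the helper name `encode_np` is hypothetical, and the `TextCodec` class below builds the same matrix with `tf.one_hot` instead.

```python
import numpy as np

def encode_np(string, vocab, start=True, stop=True):
    """One-hot encode `string` following the policy above (NumPy-only sketch)."""
    dim = len(vocab) + 2
    # str.find returns -1 for characters outside the vocabulary, so +1 maps them to 0 (UNKNOWN).
    indices = [vocab.find(ch) + 1 for ch in string]
    if stop:
        indices.append(dim - 1)  # last entry: STOP
    rows = np.eye(dim, dtype=np.float32)[np.array(indices, dtype=np.intp)]
    if start:
        # START is a zero row prepended to the sequence.
        rows = np.vstack([np.zeros((1, dim), dtype=np.float32), rows])
    return rows

enc = encode_np("ab!", "abc")
print(enc.shape)                    # (5, 5)
print(enc[1:].argmax(axis=1))       # [1 2 0 4]
```

With the toy vocabulary `"abc"`, encoding `"ab!"` gives five rows: the zero START row, then one-hot rows at indices 1 and 2, index 0 for the unknown `!`, and index 4 for STOP.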

Implementation & Test

Declaration

In [3]:
class TextCodec:
    def __init__(self, vocab):
        self._vocab = vocab
        self._dim = len(vocab) + 2

    def encode(self, string, sess=None, start=True, stop=True):
        """
        Encode a string as a sequence of one-hot vectors.
        Each character is represented as an N-dimensional one-hot vector,
        N = len(self._vocab) + 2.

        Note:
        The first entry of the vector corresponds to an unknown character.
        The last entry of the vector corresponds to the STOP signal of the sequence.
        The entries in the middle correspond to the index of the character in the vocabulary.
        The START signal is represented as a zero vector.
        """
        # str.find returns -1 for characters outside the vocabulary, so +1 maps them to 0 (UNKNOWN).
        tensor = [self._vocab.find(ch) + 1 for ch in string]
        if stop:
            tensor.append(self._dim - 1)  # String + STOP
        tensor = tf.one_hot(tensor, depth=self._dim, on_value=1.0, off_value=0.0, axis=-1, dtype=tf.float32)
        if start:
            # START + string
            tensor = tf.concat([tf.zeros([1, self._dim], dtype=tf.float32), tensor], axis=0)
        if sess is None:
            with tf.Session() as sess:
                nparray = tensor.eval()
        elif isinstance(sess, tf.Session):
            nparray = tensor.eval(session=sess)
        else:
            raise TypeError('"sess" must be {}, got {}'.format(tf.Session, type(sess)))
        return nparray

    def decode(self, nparray, default="[UNKNOWN]", start="[START]", stop="[STOP]", strip=False):
        """Decode a sequence of one-hot (or score) vectors back to text via arg-max."""
        text_list = []
        indices = np.argmax(nparray, axis=1)
        for v, ch_i in zip(nparray, indices):
            if np.all(v == 0):
                text_list.append(start if not strip else "")
            elif ch_i == 0:
                text_list.append(default)
            elif ch_i == self._dim - 1:
                text_list.append(stop if not strip else "")
            else:
                text_list.append(self._vocab[ch_i - 1])
        return "".join(text_list)

    @property
    def dim(self):
        return self._dim

Test

See how encoding and decoding work. The `!` is not in the vocabulary, so it decodes as `[UNKNOWN]`.


In [4]:
test_codec=TextCodec(vocab)
test_text_encoded=test_codec.encode("Hello world!")
print("Encoded text looks like:\n{}".format(test_text_encoded))
test_text_decoded=test_codec.decode(nparray=test_text_encoded,strip=False)
print("Decoded text looks like:\n{}".format(test_text_decoded))


Encoded text looks like:
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 1.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  1.]]
Decoded text looks like:
[START]Hello world[UNKNOWN][STOP]

Load data set


In [5]:
with open(test_text_path, "r") as f:
    raw_text_list = f.read().split("\n")
print("Loaded abstract from a total of {} theses.".format(len(raw_text_list)))
# See what we have loaded
sample_text_no = random.randint(0, len(raw_text_list)-1)
sample_text_raw = raw_text_list[sample_text_no]
print("A sample text in the data set:\n{}".format(sample_text_raw))
sample_text_encoded=test_codec.encode(sample_text_raw)
print("Encoded text:\n{}".format(sample_text_encoded))
print("Decoded text:\n{}".format(test_codec.decode(sample_text_encoded)))
encoded_data = test_codec.encode("\n".join(raw_text_list), start=False, stop=False)


Loaded abstract from a total of 7201 theses.
A sample text in the data set:
The generalization error of deep neural networks via their classification margin is studied in this work, providing novel generalization error bounds that are independent of the network depth, thereby avoiding the common exponential depth-dependency which is unrealistic for current networks with hundreds of layers. We show that a large margin linear classifier operating at the output of a deep neural network induces a large classification margin at the input of the network, provided that the network preserves distances in directions normal to the decision boundary. The distance preservation is characterized by the average behaviour of the network's Jacobian matrix in the neighbourhood of the training samples. The introduced theory also leads to a margin preservation regularization scheme that outperforms weight decay both theoretically and empirically.
Encoded text:
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  1.]]
Decoded text:
[START]The generalization error of deep neural networks via their classification margin is studied in this work, providing novel generalization error bounds that are independent of the network depth, thereby avoiding the common exponential depth-dependency which is unrealistic for current networks with hundreds of layers. We show that a large margin linear classifier operating at the output of a deep neural network induces a large classification margin at the input of the network, provided that the network preserves distances in directions normal to the decision boundary. The distance preservation is characterized by the average behaviour of the network's Jacobian matrix in the neighbourhood of the training samples. The introduced theory also leads to a margin preservation regularization scheme that outperforms weight decay both theoretically and empirically.[STOP]

Define Batch Generator


In [6]:
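def batch_generator(data, codec, batch_size, seq_length, reset_every):
    """
    Yield (input, target) batches of shape [batch_size, seq_length, codec.dim].

    Each row of a batch is a contiguous chunk of text wrapped in START/STOP,
    seq_length * reset_every + 1 steps long, and is cut into reset_every
    consecutive windows so the RNN state can be carried over between windows.
    Targets are the inputs shifted by one step.
    """
    if isinstance(data, str):
        data = codec.encode(data, start=False, stop=False)
    head = 0
    batch = []
    # Text characters per row; START and STOP bring a row to seq_length * reset_every + 1 steps.
    increment = seq_length * reset_every - 1
    extras = codec.encode("", start=True, stop=True)
    v_start, v_stop = extras[0: 1, :], extras[1: 2, :]
    while head < np.shape(data)[0] or len(batch) == batch_size:
        if len(batch) == batch_size:
            batch = np.array(batch)
            for offset in range(reset_every):
                yield (batch[:, offset * seq_length: (offset + 1) * seq_length, :],
                       batch[:, offset * seq_length + 1: (offset + 1) * seq_length + 1, :])
            batch = []
        else:
            seq = np.concatenate([v_start, data[head: head + increment, :], v_stop], axis=0)
            # Drop the incomplete chunk at the end of the data.
            if np.shape(seq)[0] == (increment + 2):
                batch.append(seq)
            head += increment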
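The window arithmetic can be checked on a toy integer row. This sketch (the helper name `windows` is hypothetical) mimics how `batch_generator` slices a single row: a row holds `seq_length * reset_every + 1` steps, and each target window is its input window shifted by one step, so the last target of the final window is the STOP step.

```python
import numpy as np

def windows(row, seq_length, reset_every):
    """Split one padded row into (input, target) windows the way batch_generator slices it."""
    assert len(row) == seq_length * reset_every + 1
    for offset in range(reset_every):
        lo = offset * seq_length
        yield row[lo: lo + seq_length], row[lo + 1: lo + seq_length + 1]

row = np.arange(7)  # seq_length=3, reset_every=2 -> 3 * 2 + 1 = 7 steps
pairs = list(windows(row, seq_length=3, reset_every=2))
for x, y in pairs:
    print(x, y)
```

Step 0 plays the role of START and step 6 the role of STOP, which matches the `[START]` at the beginning of each reset batch and the `[STOP]` at the end of the following one in the check below.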

Check the generator


In [7]:
seq_length = 100
reset_every = 2
batch_size = 2
batches = batch_generator(data=encoded_data, 
                               codec=test_codec, 
                               batch_size=batch_size, 
                               seq_length=seq_length, 
                               reset_every=reset_every)
for (x, y), i in zip(batches, range(reset_every * 2)):
    print("Batch {}".format(i))
    if (i % reset_every) == 0:
        print("Reset")
    for j in range(batch_size):
        decoded_x, decoded_y = test_codec.decode(x[j], strip=False), test_codec.decode(y[j], strip=False)
        print("Index of sub-sequence:\n{}\nSequence input:\n{}:\nSequence output:\n{}".format(j, 
                                                                                          decoded_x, 
                                                                                          decoded_y))
del seq_length, reset_every, batch_size, batches


Batch 0
Reset
Index of sub-sequence:
0
Sequence input:
[START]In science and engineering, intelligent processing of complex signals such as images, sound or lang:
Sequence output:
In science and engineering, intelligent processing of complex signals such as images, sound or langu
Index of sub-sequence:
1
Sequence input:
[START]gically inspired. Hierarchical systems (or, more generally, nested systems) offer a way to generate:
Sequence output:
gically inspired. Hierarchical systems (or, more generally, nested systems) offer a way to generate 
Batch 1
Index of sub-sequence:
0
Sequence input:
uage is often performed by a parameterized hierarchy of nonlinear processing layers, sometimes biolo:
Sequence output:
age is often performed by a parameterized hierarchy of nonlinear processing layers, sometimes biolo[STOP]
Index of sub-sequence:
1
Sequence input:
 complex mappings using simple stages. Each layer performs a different operation and achieves an eve:
Sequence output:
complex mappings using simple stages. Each layer performs a different operation and achieves an eve[STOP]
Batch 2
Reset
Index of sub-sequence:
0
Sequence input:
[START]r more sophisticated representation of the input, as, for example, in an deep artificial neural net:
Sequence output:
r more sophisticated representation of the input, as, for example, in an deep artificial neural netw
Index of sub-sequence:
1
Sequence input:
[START]ation of the parameters of all the layers and selection of an optimal architecture is widely consid:
Sequence output:
ation of the parameters of all the layers and selection of an optimal architecture is widely conside
Batch 3
Index of sub-sequence:
0
Sequence input:
work, an object recognition cascade in computer vision or a speech front-end processing. Joint estim:
Sequence output:
ork, an object recognition cascade in computer vision or a speech front-end processing. Joint estim[STOP]
Index of sub-sequence:
1
Sequence input:
ered to be a difficult numerical nonconvex optimization problem, difficult to parallelize for execut:
Sequence output:
red to be a difficult numerical nonconvex optimization problem, difficult to parallelize for execut[STOP]

Define model class


In [8]:
class DRNN(tf.nn.rnn_cell.RNNCell):
    """Stacked GRU cell: input projection, skip connections from the projected input,
    and an output layer summing a projection of every hidden layer's output."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layer, dtype=tf.float32):
        super(DRNN, self).__init__(dtype=dtype)
        assert isinstance(input_dim, int) and input_dim > 0, "Invalid input dimension."
        self._input_dim = input_dim
        assert isinstance(num_hidden_layer, int) and num_hidden_layer > 0, "Invalid number of hidden layers."
        self._num_hidden_layer = num_hidden_layer
        assert isinstance(hidden_dim, int) and hidden_dim > 0, "Invalid dimension of hidden states."
        self._hidden_dim = hidden_dim
        assert isinstance(output_dim, int) and output_dim > 0, "Invalid output dimension."
        self._output_dim = output_dim
        self._state_is_tuple = True
        with tf.variable_scope("input_layer"):
            self._W_xh = tf.get_variable("W_xh", shape=[self._input_dim, self._hidden_dim])
            self._b_xh = tf.get_variable("b_xh", shape=[self._hidden_dim])
        with tf.variable_scope("rnn_layers"):
            self._cells = [tf.nn.rnn_cell.GRUCell(self._hidden_dim) for _ in range(num_hidden_layer)]
        with tf.variable_scope("output_layer"):
            self._W_ho_list = [tf.get_variable("W_h{}o".format(i), shape=[self._hidden_dim, self._output_dim])
                               for i in range(num_hidden_layer)]
            self._b_ho = tf.get_variable("b_ho", shape=[self._output_dim])

    @property
    def output_size(self):
        return self._output_dim

    @property
    def state_size(self):
        return (self._hidden_dim,) * self._num_hidden_layer

    def zero_state(self, batch_size, dtype):
        if self._state_is_tuple:
            return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
        else:
            raise NotImplementedError("Not implemented yet.")

    def __call__(self, _input, state, scope=None):
        assert isinstance(state, tuple) and len(state) == self._num_hidden_layer, (
            "state must be a tuple of size {}".format(self._num_hidden_layer))
        # Project the one-hot input into the hidden space.
        hidden_layer_input = tf.matmul(_input, self._W_xh) + self._b_xh
        prev_output = hidden_layer_input
        final_state = []
        output = None
        for hidden_layer_index, hidden_cell in enumerate(self._cells):
            with tf.variable_scope("cell_{}".format(hidden_layer_index)):
                new_output, new_state = hidden_cell(prev_output, state[hidden_layer_index])
                # Skip connection: the next layer sees this layer's output plus the projected input.
                # The addition creates no variables, so its variable scope does not matter.
                prev_output = new_output + hidden_layer_input
                final_state.append(new_state)
            # Every hidden layer contributes its own projection to the output.
            _W_ho = self._W_ho_list[hidden_layer_index]
            if output is None:
                output = tf.matmul(new_output, _W_ho)
            else:
                output = output + tf.matmul(new_output, _W_ho)
        # Return raw logits: the loss applies the softmax itself, and a tanh here would
        # confine every logit to [-1, 1] and cap how confident the softmax can become.
        output = output + self._b_ho
        return output, tuple(final_state)

    def inspect_weights(self, sess):
        """Print every weight matrix and bias together with its Frobenius norm."""
        named_vars = [("W_xh", self._W_xh), ("b_xh", self._b_xh)]
        named_vars += [("W_h{}o".format(i), w) for i, w in enumerate(self._W_ho_list)]
        named_vars.append(("b_ho", self._b_ho))
        for name, var in named_vars:
            val = var.eval(session=sess)
            print("{}:\n{}\nF-norm:\n{}".format(name, val, np.linalg.norm(val)))
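Read off the code above, one time step of `DRNN` with $L$ = `num_hidden_layer` GRU layers computes the following (row-vector convention; $x_t$ is the one-hot input, $h_t^{(i)}$ the state of layer $i$, and $z_t$ the output before any output nonlinearity):

$$
\begin{aligned}
u_t &= x_t W_{xh} + b_{xh},\\
h_t^{(1)} &= \mathrm{GRU}_1\big(u_t,\ h_{t-1}^{(1)}\big),\\
h_t^{(i)} &= \mathrm{GRU}_i\big(h_t^{(i-1)} + u_t,\ h_{t-1}^{(i)}\big), \quad i = 2, \dots, L,\\
z_t &= b_{ho} + \sum_{i=1}^{L} h_t^{(i)} W_{h_i o}.
\end{aligned}
$$

Since a GRU's output is its new state, the stacked states reach the output only through the separate matrices $W_{h_i o}$.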

Make an instance of the model and define the rest of the graph

Thoughts

If GRU cells are used, their outputs should not be used directly as the desired output without a further transform, because a GRU returns its new hidden state as its output. For example, if a cell accepts two inputs, the state from the previous cell and the input of this cell (which is approximated by the state input), then the RNN cell can be treated as an ordinary feed-forward network.

The proposal above still needs to be retested, because earlier training runs had a bug: the final RNN state of one sequence was not fed as the initial state of the next.
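To make the point concrete, here is a NumPy sketch of a single GRU step. It follows one common formulation (update rule `new_h = u * h + (1 - u) * c`); the names `gru_step` and `params` are hypothetical and all weights are random placeholders. The step returns the same array as output and state, so a separate readout is needed to produce output logits.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x, h, params):
    """One GRU step; returns (output, new_state), which are the same array."""
    W_r, W_u, W_c, b_r, b_u, b_c = params
    xh = np.concatenate([x, h], axis=-1)
    r = sigmoid(xh @ W_r + b_r)  # reset gate
    u = sigmoid(xh @ W_u + b_u)  # update gate
    c = np.tanh(np.concatenate([x, r * h], axis=-1) @ W_c + b_c)  # candidate state
    new_h = u * h + (1.0 - u) * c
    return new_h, new_h  # the output is the new state

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim = 4, 3, 5
params = tuple(rng.normal(size=(input_dim + hidden_dim, hidden_dim)) for _ in range(3)) + (
    np.zeros(hidden_dim), np.zeros(hidden_dim), np.zeros(hidden_dim))
x = rng.normal(size=(1, input_dim))
h0 = np.zeros((1, hidden_dim))
out, h1 = gru_step(x, h0, params)
# The output lives in the hidden space and is bounded by tanh,
# so a separate linear readout maps it to output logits.
W_ho = rng.normal(size=(hidden_dim, output_dim))
b_ho = np.zeros(output_dim)
logits = out @ W_ho + b_ho
print(out.shape, logits.shape)
```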


In [9]:
tf.reset_default_graph()
input_dim = output_dim = test_codec.dim
hidden_dim = 700
num_hidden_layer = 3
rnn_cell = DRNN(input_dim=input_dim, output_dim=output_dim, num_hidden_layer=num_hidden_layer, hidden_dim=hidden_dim)
batch_size = 50
init_state = tuple(tf.placeholder_with_default(input=tensor, 
                                         shape=[None, hidden_dim]) for tensor in rnn_cell.zero_state(
    batch_size=batch_size, dtype=tf.float32))
seq_input = tf.placeholder(name="batch_input", shape=[None, None, input_dim], dtype=tf.float32)
target_seq_output = tf.placeholder(name="target_batch_output", shape=[None, None, output_dim], dtype=tf.float32)
seq_output, final_states = tf.nn.dynamic_rnn(cell=rnn_cell,inputs=seq_input, 
                                                      initial_state=init_state, dtype=tf.float32)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target_seq_output, logits=seq_output))
summary_op = tf.summary.scalar(tensor=loss, name="loss")
global_step = tf.get_variable(name="global_step", initializer=0, trainable=False)
lr = tf.get_variable(name="learning_rate", initializer=1.0, trainable=False)
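The loss above treats `seq_output` as logits. If the output layer squashes those logits with `tanh`, every logit lies in $[-1, 1]$, and the best a softmax over $n$ classes can do is $p_{\max} = e^{2b} / (e^{2b} + n - 1)$ for bound $b$. The sketch below (helper name `max_softmax_prob` is hypothetical) evaluates this numerically for `len(vocab) + 2` classes of the vocabulary in the configuration cell; with $b = 1$ the cap is about 0.08.

```python
import numpy as np

def max_softmax_prob(logit_bound, num_classes):
    """Largest probability softmax can give one class when all logits lie in [-bound, bound]."""
    logits = np.full(num_classes, -logit_bound, dtype=np.float64)
    logits[0] = logit_bound  # best case: target logit at the top, all others at the bottom
    shifted = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return shifted[0] / shifted.sum()

num_classes = 86  # len(vocab) + 2 for the vocabulary in the configuration cell
for bound in (1.0, 5.0, 20.0):
    print(bound, max_softmax_prob(bound, num_classes))
```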

Training
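The training cell below multiplies `lr` by `learning_rate_decay` at the end of every epoch, so epoch $e$ is meant to run at $10^{-3} \cdot 0.9^{e}$. This schedule only affects training if the optimizer is built from the `lr` variable; an optimizer constructed with the Python float `learning_rate` keeps a fixed step size. A quick sketch of the intended values (the helper `lr_schedule` is hypothetical):

```python
def lr_schedule(base_lr, decay, n_epoch):
    """Learning rate used in each epoch when lr is multiplied by `decay` after every epoch."""
    return [base_lr * decay ** epoch for epoch in range(n_epoch)]

rates = lr_schedule(base_lr=1e-3, decay=0.9, n_epoch=50)
print(rates[0], rates[1], rates[-1])
```

By the last of the 50 epochs the rate has fallen to roughly $5.7 \times 10^{-6}$.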


In [ ]:
n_epoch=50
learning_rate=1e-3
# Build the optimizer from the `lr` variable so the per-epoch decay below actually takes effect.
train_op=tf.train.AdamOptimizer(learning_rate=lr).minimize(loss, global_step=global_step)
print_every = 50
save_every = 1000
partition_size = 100
logdir = os.path.normpath("./graphs")
seq_length = 100
reset_every = 100
visualize_every = 100
learning_rate_decay = 0.9
# batch_size has been specified when configuring the tensors for the initial states

keep_checkpoint_every_n_hours = 0.5
model_checkpoint_dir = os.path.normpath("./model_checkpoints")
model_checkpoint_path = os.path.join(model_checkpoint_dir, "DRNN")
saver = tf.train.Saver(keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
batches = list(batch_generator(data=encoded_data, 
                               codec=test_codec, 
                               batch_size=batch_size, 
                               seq_length=seq_length, 
                               reset_every=reset_every))
with tf.Session() as sess, tf.summary.FileWriter(logdir=logdir) as writer:
    sess.run(tf.global_variables_initializer())
    feed_dict = dict()
    states = None
    sess.run(tf.assign(lr, learning_rate))
    zero_states = sess.run(rnn_cell.zero_state(batch_size=1, dtype=tf.float32))
    for epoch in range(n_epoch):
        assert lr.eval(sess) > 0, "learning_rate must be positive."
        for i, (x, y) in enumerate(batches):
            feed_dict = {seq_input: x, target_seq_output: y}
            if (i % reset_every) != 0 and states is not None:
                for j in range(len(init_state)):
                    feed_dict[init_state[j]] = states[j]
            _, summary, states, step = sess.run(fetches=[train_op, summary_op, final_states, global_step], 
                                                feed_dict=feed_dict)
            writer.add_summary(summary=summary, global_step=step)
            if ((step + 1) % save_every) == 0:
                saver.save(sess=sess, save_path=model_checkpoint_path, global_step=step)
            if (step % visualize_every) == 0:
                feed_dict = {seq_input: x[:1, : , :]}
                for key, value in zip(init_state, zero_states):
                    feed_dict[key] = value
                sample_output = sess.run(seq_output, feed_dict=feed_dict)
                print(test_codec.decode(sample_output[0], strip=False))
        sess.run(tf.assign(lr, lr.eval(sess) * learning_rate_decay))


                                                                                                    
                                                                                                    
dd                                                                                                  
dd                                                                                                  
fd                                                                                                  
ff                                                                                                  
fd                                                                                                  
f                                                                                                   
f                                                                                                   
f                                                                                                   
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
f                                                                                                   
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
f                                                                                                   
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    
                                                                                                    

Test online inference


In [ ]:
def online_inference(cell, prime, sess, codec,
                     input_tensor,
                     init_state_tensor_tuple,
                     output_tensor,
                     final_state_tensor_tuple,
                     length):
    """Greedily extend `prime` to at most `length` characters, one character per step."""
    final_output = [prime]
    zero_states = sess.run(cell.zero_state(batch_size=1, dtype=tf.float32))
    # Prime the network with START + prime, starting from zero states.
    feed_dict = {input_tensor: codec.encode(prime, sess=sess, start=True, stop=False)[np.newaxis, :, :]}
    for init_state_tensor, init_state_value in zip(init_state_tensor_tuple, zero_states):
        feed_dict[init_state_tensor] = init_state_value
    output, final_states = sess.run([output_tensor, final_state_tensor_tuple], feed_dict=feed_dict)
    for _ in range(length - len(prime)):
        # Greedy choice: the arg-max of the last output step.
        ch_i = int(np.argmax(output[0, -1]))
        if ch_i == codec.dim - 1:
            break  # the model emitted STOP
        # Feed the chosen character back as a one-hot vector, without re-encoding decoded text.
        next_vec = np.zeros((1, 1, codec.dim), dtype=np.float32)
        next_vec[0, 0, ch_i] = 1.0
        final_output.append(codec.decode(next_vec[0], strip=False))
        feed_dict = {input_tensor: next_vec}
        for init_state_tensor, init_state_value in zip(init_state_tensor_tuple, final_states):
            feed_dict[init_state_tensor] = init_state_value
        output, final_states = sess.run([output_tensor, final_state_tensor_tuple], feed_dict=feed_dict)
    return "".join(final_output)

In [ ]:
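saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the newest checkpoint. A freshly built Saver has no record of earlier saves,
    # so saver.last_checkpoints would be empty and the variables would stay uninitialized.
    ckpt = tf.train.latest_checkpoint(model_param_path)
    print(ckpt)
    saver.restore(sess, ckpt)
    print(online_inference(rnn_cell, "We propose",
                           sess, test_codec, seq_input, init_state, seq_output, final_states, 200))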
