Anna KaRNNa

In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.

This network is based on Andrej Karpathy's post on RNNs and his implementation in Torch. I also drew on information from r2rt and from Sherjil Ozair's implementation on GitHub. Below is the general architecture of the character-wise RNN.


In [1]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

First we'll load the text file and convert it into integers for our network to use.


In [2]:
with open('anna.txt', 'r') as f:
    text=f.read()
vocab = set(text)
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))
chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)

In [3]:
text[:100]


Out[3]:
'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars[:100]


Out[4]:
array([64, 50,  8, 60, 69,  1, 27, 35, 10, 49, 49, 49, 51,  8, 60, 60, 28,
       35, 16,  8, 14,  4, 68,  4,  1, 13, 35,  8, 27,  1, 35,  8, 68, 68,
       35,  8, 68,  4, 70,  1, 71, 35,  1, 74,  1, 27, 28, 35, 41, 43, 50,
        8, 60, 60, 28, 35, 16,  8, 14,  4, 68, 28, 35,  4, 13, 35, 41, 43,
       50,  8, 60, 60, 28, 35,  4, 43, 35,  4, 69, 13, 35, 11, 34, 43, 49,
       34,  8, 28, 52, 49, 49, 23, 74,  1, 27, 28, 69, 50,  4, 43], dtype=int32)
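As a quick sanity check, the integer array can be decoded back into text with int_to_vocab; the result should reproduce the start of the book.

# Decode the first 20 integers back to characters; should match text[:20]
''.join(int_to_vocab[i] for i in chars[:20])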

Now I need to split up the data into batches, and into training and validation sets. I should be making a test set here, but I'm not going to worry about that; my test will be whether the network can generate new text.

Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last bit of data so that I'll only have completely full batches.

The idea here is to make a 2D matrix where the number of rows is equal to the batch size, that is, the number of sequences we feed through the network in parallel. Each row will be one long concatenated string from the character data. We'll split this data into a training set and validation set using the split_frac keyword argument. This will keep 90% of the batches in the training set and the other 10% in the validation set.
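To make the reshaping concrete, here's a minimal sketch with made-up numbers (26 characters, batch_size=2, num_steps=4); it mirrors what the split_data function below will do.

import numpy as np

# Toy illustration of the reshaping: 26 characters, batch_size=2, num_steps=4
toy = np.arange(26)
batch_size, num_steps = 2, 4
slice_size = batch_size * num_steps  # 8 characters consumed per batch
n_batches = len(toy) // slice_size   # 3 full batches; the 2 leftover characters are dropped
x = np.stack(np.split(toy[:n_batches*slice_size], batch_size))
print(x.shape)                       # (2, 12): batch_size rows, n_batches*num_steps columns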


In [5]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """ 
    Split character data into training and validation sets, inputs and targets for each set.
    
    Arguments
    ---------
    chars: character array
    batch_size: Number of sequences in each batch
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set
    
    
    Returns train_x, train_y, val_x, val_y
    """
    
    slice_size = batch_size * num_steps
    n_batches = int(len(chars) / slice_size)
    
    # Drop the last few characters to make only full batches
    x = chars[: n_batches*slice_size]
    y = chars[1: n_batches*slice_size + 1]
    
    # Split the data into batch_size slices, then stack them into a 2D matrix 
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    
    # Split into training and validation sets, keeping the first split_frac of the batches for training
    split_idx = int(n_batches*split_frac)
    train_x, train_y = x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]
    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]
    
    return train_x, train_y, val_x, val_y

In [6]:
train_x, train_y, val_x, val_y = split_data(chars, 10, 200)

In [7]:
train_x.shape


Out[7]:
(10, 178400)

In [8]:
train_x[:,:10]


Out[8]:
array([[64, 50,  8, 60, 69,  1, 27, 35, 10, 49],
       [ 0, 43,  6, 35, 50,  1, 35, 14, 11, 74],
       [35,  2,  8, 69,  2, 50,  4, 43, 42, 35],
       [11, 69, 50,  1, 27, 35, 34, 11, 41, 68],
       [35, 69, 50,  1, 35, 68,  8, 43,  6, 32],
       [35,  5, 50, 27, 11, 41, 42, 50, 35, 68],
       [69, 35, 69, 11, 49,  6, 11, 52, 49, 49],
       [11, 35, 50,  1, 27, 13,  1, 68, 16, 20],
       [50,  8, 69, 35,  4, 13, 35, 69, 50,  1],
       [ 1, 27, 13,  1, 68, 16, 35,  8, 43,  6]], dtype=int32)

I'll write another function to grab batches out of the arrays made by split_data. Here each batch will be a sliding window on these arrays with size batch_size x num_steps. For example, if we want our network to train on sequences of 100 characters, num_steps = 100. For the next batch, we'll shift this window by num_steps characters. In this way we can feed batches to the network and the cell state will carry over from one batch to the next.


In [9]:
def get_batch(arrs, num_steps):
    batch_size, slice_size = arrs[0].shape
    
    # Number of full windows of num_steps we can slide across the data
    n_batches = int(slice_size/num_steps)
    for b in range(n_batches):
        # Take the b-th window of num_steps columns from each array
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]
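
As a quick usage sketch (assuming the train_x and train_y arrays created by split_data above), the first item from this generator is a pair of batch_size x num_steps arrays:

# Peek at the first training batch; each array is batch_size x num_steps
x, y = next(get_batch([train_x, train_y], num_steps=200))
print(x.shape, y.shape)  # (10, 200) (10, 200) for the split_data call above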

In [10]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
        
    if sampling:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    with tf.name_scope('inputs'):
        inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
        x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')
    
    with tf.name_scope('targets'):
        targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
        y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')
        y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # Build the RNN layers
    with tf.name_scope("RNN_cells"):
        lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
        drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
        cell = tf.contrib.rnn.MultiRNNCell([drop] * num_layers)
    
    with tf.name_scope("RNN_init_state"):
        initial_state = cell.zero_state(batch_size, tf.float32)

    # Run the data through the RNN layers
    with tf.name_scope("RNN_forward"):
        outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=initial_state)
    
    final_state = state
    
    # Reshape output so it's a bunch of rows, one row for each cell output
    with tf.name_scope('sequence_reshape'):
        seq_output = tf.concat(outputs, axis=1, name='seq_output')
        output = tf.reshape(seq_output, [-1, lstm_size], name='graph_output')
    
    # Now connect the RNN outputs to a softmax layer and calculate the cost
    with tf.name_scope('logits'):
        softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),
                               name='softmax_w')
        softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')
        logits = tf.matmul(output, softmax_w) + softmax_b
        tf.summary.histogram('softmax_w', softmax_w)
        tf.summary.histogram('softmax_b', softmax_b)

    with tf.name_scope('predictions'):
        preds = tf.nn.softmax(logits, name='predictions')
        tf.summary.histogram('predictions', preds)
    
    with tf.name_scope('cost'):
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')
        cost = tf.reduce_mean(loss, name='cost')
        tf.summary.scalar('cost', cost)

    # Optimizer for training, using gradient clipping to control exploding gradients
    with tf.name_scope('train'):
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
        train_op = tf.train.AdamOptimizer(learning_rate)
        optimizer = train_op.apply_gradients(zip(grads, tvars))
    
    merged = tf.summary.merge_all()
    
    # Export the nodes 
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer', 'merged']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph

Hyperparameters

Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are lstm_size and num_layers. These set the number of hidden units in the LSTM layers and the number of LSTM layers, respectively. Of course, making these bigger will improve the network's performance but you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting. Decrease the size of the network or decrease the dropout keep probability.


In [11]:
batch_size = 100
num_steps = 100
lstm_size = 512
num_layers = 2
learning_rate = 0.001

Training

Time for training, which is pretty straightforward. Here I pass in some data and get an LSTM state back. Then I pass that state back in to the network so the next batch can continue the state from the previous batch. And every so often (set by save_every_n) I calculate the validation loss and save a checkpoint.


In [12]:
!mkdir -p checkpoints/anna

In [13]:
epochs = 10
save_every_n = 100
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter('./logs/2/train', sess.graph)
    test_writer = tf.summary.FileWriter('./logs/2/test')
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/anna20.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: 0.5,
                    model.initial_state: new_state}
            summary, batch_loss, new_state, _ = sess.run([model.merged, model.cost, 
                                                          model.final_state, model.optimizer], 
                                                          feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
            
            train_writer.add_summary(summary, iteration)
        
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance, notice dropout has been set to 1
                val_loss = []
                new_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: new_state}
                    summary, batch_loss, new_state = sess.run([model.merged, model.cost, 
                                                               model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)
                    
                test_writer.add_summary(summary, iteration)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                #saver.save(sess, "checkpoints/anna/i{}_l{}_{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))


Epoch 1/10  Iteration 1/1780 Training loss: 4.4168 1.8580 sec/batch
Epoch 1/10  Iteration 2/1780 Training loss: 4.3664 0.4135 sec/batch
Epoch 1/10  Iteration 3/1780 Training loss: 4.1648 0.4161 sec/batch
Epoch 1/10  Iteration 4/1780 Training loss: 4.2122 0.4184 sec/batch
Epoch 1/10  Iteration 5/1780 Training loss: 4.1397 0.4204 sec/batch
Epoch 1/10  Iteration 6/1780 Training loss: 4.0702 0.4154 sec/batch
Epoch 1/10  Iteration 7/1780 Training loss: 3.9940 0.4164 sec/batch
Epoch 1/10  Iteration 8/1780 Training loss: 3.9251 0.4157 sec/batch
Epoch 1/10  Iteration 9/1780 Training loss: 3.8664 0.4202 sec/batch
Epoch 1/10  Iteration 10/1780 Training loss: 3.8158 0.4179 sec/batch
Epoch 1/10  Iteration 11/1780 Training loss: 3.7705 0.4266 sec/batch
Epoch 1/10  Iteration 12/1780 Training loss: 3.7329 0.4204 sec/batch
Epoch 1/10  Iteration 13/1780 Training loss: 3.6993 0.4151 sec/batch
Epoch 1/10  Iteration 14/1780 Training loss: 3.6699 0.4162 sec/batch
Epoch 1/10  Iteration 15/1780 Training loss: 3.6434 0.4162 sec/batch
Epoch 1/10  Iteration 16/1780 Training loss: 3.6203 0.4146 sec/batch
Epoch 1/10  Iteration 17/1780 Training loss: 3.5985 0.4162 sec/batch
Epoch 1/10  Iteration 18/1780 Training loss: 3.5807 0.4167 sec/batch
Epoch 1/10  Iteration 19/1780 Training loss: 3.5632 0.4341 sec/batch
Epoch 1/10  Iteration 20/1780 Training loss: 3.5462 0.4348 sec/batch
Epoch 1/10  Iteration 21/1780 Training loss: 3.5314 0.4345 sec/batch
Epoch 1/10  Iteration 22/1780 Training loss: 3.5177 0.4352 sec/batch
Epoch 1/10  Iteration 23/1780 Training loss: 3.5047 0.4073 sec/batch
Epoch 1/10  Iteration 24/1780 Training loss: 3.4925 0.4114 sec/batch
Epoch 1/10  Iteration 25/1780 Training loss: 3.4811 0.4123 sec/batch
Epoch 1/10  Iteration 26/1780 Training loss: 3.4707 0.4204 sec/batch
Epoch 1/10  Iteration 27/1780 Training loss: 3.4615 0.4401 sec/batch
Epoch 1/10  Iteration 28/1780 Training loss: 3.4516 0.4381 sec/batch
Epoch 1/10  Iteration 29/1780 Training loss: 3.4432 0.4152 sec/batch
Epoch 1/10  Iteration 30/1780 Training loss: 3.4352 0.4104 sec/batch
Epoch 1/10  Iteration 31/1780 Training loss: 3.4284 0.4107 sec/batch
Epoch 1/10  Iteration 32/1780 Training loss: 3.4209 0.4177 sec/batch
Epoch 1/10  Iteration 33/1780 Training loss: 3.4134 0.4383 sec/batch
Epoch 1/10  Iteration 34/1780 Training loss: 3.4074 0.4246 sec/batch
Epoch 1/10  Iteration 35/1780 Training loss: 3.4008 0.4359 sec/batch
Epoch 1/10  Iteration 36/1780 Training loss: 3.3952 0.4393 sec/batch
Epoch 1/10  Iteration 37/1780 Training loss: 3.3889 0.4382 sec/batch
Epoch 1/10  Iteration 38/1780 Training loss: 3.3829 0.4360 sec/batch
Epoch 1/10  Iteration 39/1780 Training loss: 3.3772 0.4239 sec/batch
Epoch 1/10  Iteration 40/1780 Training loss: 3.3719 0.4092 sec/batch
Epoch 1/10  Iteration 41/1780 Training loss: 3.3666 0.4057 sec/batch
Epoch 1/10  Iteration 42/1780 Training loss: 3.3616 0.4069 sec/batch
Epoch 1/10  Iteration 43/1780 Training loss: 3.3567 0.4101 sec/batch
Epoch 1/10  Iteration 44/1780 Training loss: 3.3520 0.4080 sec/batch
Epoch 1/10  Iteration 45/1780 Training loss: 3.3474 0.4089 sec/batch
Epoch 1/10  Iteration 46/1780 Training loss: 3.3433 0.4065 sec/batch
Epoch 1/10  Iteration 47/1780 Training loss: 3.3394 0.4146 sec/batch
Epoch 1/10  Iteration 48/1780 Training loss: 3.3358 0.4083 sec/batch
Epoch 1/10  Iteration 49/1780 Training loss: 3.3323 0.4067 sec/batch
Epoch 1/10  Iteration 50/1780 Training loss: 3.3289 0.4074 sec/batch
Epoch 1/10  Iteration 51/1780 Training loss: 3.3255 0.4073 sec/batch
Epoch 1/10  Iteration 52/1780 Training loss: 3.3220 0.4081 sec/batch
Epoch 1/10  Iteration 53/1780 Training loss: 3.3189 0.4051 sec/batch
Epoch 1/10  Iteration 54/1780 Training loss: 3.3155 0.4049 sec/batch
Epoch 1/10  Iteration 55/1780 Training loss: 3.3125 0.4093 sec/batch
Epoch 1/10  Iteration 56/1780 Training loss: 3.3093 0.4052 sec/batch
Epoch 1/10  Iteration 57/1780 Training loss: 3.3064 0.4096 sec/batch
Epoch 1/10  Iteration 58/1780 Training loss: 3.3036 0.4063 sec/batch
Epoch 1/10  Iteration 59/1780 Training loss: 3.3006 0.4047 sec/batch
Epoch 1/10  Iteration 60/1780 Training loss: 3.2980 0.4074 sec/batch
Epoch 1/10  Iteration 61/1780 Training loss: 3.2954 0.4050 sec/batch
Epoch 1/10  Iteration 62/1780 Training loss: 3.2932 0.4101 sec/batch
Epoch 1/10  Iteration 63/1780 Training loss: 3.2911 0.4070 sec/batch
Epoch 1/10  Iteration 64/1780 Training loss: 3.2884 0.4072 sec/batch
Epoch 1/10  Iteration 65/1780 Training loss: 3.2859 0.4101 sec/batch
Epoch 1/10  Iteration 66/1780 Training loss: 3.2838 0.4057 sec/batch
Epoch 1/10  Iteration 67/1780 Training loss: 3.2817 0.4091 sec/batch
Epoch 1/10  Iteration 68/1780 Training loss: 3.2788 0.4054 sec/batch
Epoch 1/10  Iteration 69/1780 Training loss: 3.2765 0.4071 sec/batch
Epoch 1/10  Iteration 70/1780 Training loss: 3.2744 0.4077 sec/batch
Epoch 1/10  Iteration 71/1780 Training loss: 3.2723 0.4051 sec/batch
Epoch 1/10  Iteration 72/1780 Training loss: 3.2705 0.4073 sec/batch
Epoch 1/10  Iteration 73/1780 Training loss: 3.2684 0.4054 sec/batch
Epoch 1/10  Iteration 74/1780 Training loss: 3.2664 0.4054 sec/batch
Epoch 1/10  Iteration 75/1780 Training loss: 3.2646 0.4124 sec/batch
Epoch 1/10  Iteration 76/1780 Training loss: 3.2629 0.4076 sec/batch
Epoch 1/10  Iteration 77/1780 Training loss: 3.2610 0.4078 sec/batch
Epoch 1/10  Iteration 78/1780 Training loss: 3.2592 0.4056 sec/batch
Epoch 1/10  Iteration 79/1780 Training loss: 3.2574 0.4091 sec/batch
Epoch 1/10  Iteration 80/1780 Training loss: 3.2554 0.4076 sec/batch
Epoch 1/10  Iteration 81/1780 Training loss: 3.2534 0.4055 sec/batch
Epoch 1/10  Iteration 82/1780 Training loss: 3.2517 0.4089 sec/batch
Epoch 1/10  Iteration 83/1780 Training loss: 3.2501 0.4048 sec/batch
Epoch 1/10  Iteration 84/1780 Training loss: 3.2483 0.4096 sec/batch
Epoch 1/10  Iteration 85/1780 Training loss: 3.2463 0.4053 sec/batch
Epoch 1/10  Iteration 86/1780 Training loss: 3.2445 0.4053 sec/batch
Epoch 1/10  Iteration 87/1780 Training loss: 3.2428 0.4090 sec/batch
Epoch 1/10  Iteration 88/1780 Training loss: 3.2410 0.4054 sec/batch
Epoch 1/10  Iteration 89/1780 Training loss: 3.2395 0.4094 sec/batch
Epoch 1/10  Iteration 90/1780 Training loss: 3.2379 0.4075 sec/batch
Epoch 1/10  Iteration 91/1780 Training loss: 3.2364 0.4059 sec/batch
Epoch 1/10  Iteration 92/1780 Training loss: 3.2347 0.4171 sec/batch
Epoch 1/10  Iteration 93/1780 Training loss: 3.2330 0.4212 sec/batch
Epoch 1/10  Iteration 94/1780 Training loss: 3.2315 0.4086 sec/batch
Epoch 1/10  Iteration 95/1780 Training loss: 3.2298 0.4072 sec/batch
Epoch 1/10  Iteration 96/1780 Training loss: 3.2281 0.4072 sec/batch
Epoch 1/10  Iteration 97/1780 Training loss: 3.2265 0.4139 sec/batch
Epoch 1/10  Iteration 98/1780 Training loss: 3.2247 0.4072 sec/batch
Epoch 1/10  Iteration 99/1780 Training loss: 3.2231 0.4091 sec/batch
Epoch 1/10  Iteration 100/1780 Training loss: 3.2214 0.4048 sec/batch
Validation loss: 3.0025 Saving checkpoint!
Epoch 1/10  Iteration 101/1780 Training loss: 3.2196 0.4075 sec/batch
Epoch 1/10  Iteration 102/1780 Training loss: 3.2179 0.4193 sec/batch
Epoch 1/10  Iteration 103/1780 Training loss: 3.2161 0.4078 sec/batch
Epoch 1/10  Iteration 104/1780 Training loss: 3.2141 0.4094 sec/batch
Epoch 1/10  Iteration 105/1780 Training loss: 3.2123 0.4088 sec/batch
Epoch 1/10  Iteration 106/1780 Training loss: 3.2105 0.4109 sec/batch
Epoch 1/10  Iteration 107/1780 Training loss: 3.2085 0.4083 sec/batch
Epoch 1/10  Iteration 108/1780 Training loss: 3.2066 0.4076 sec/batch
Epoch 1/10  Iteration 109/1780 Training loss: 3.2048 0.4109 sec/batch
Epoch 1/10  Iteration 110/1780 Training loss: 3.2026 0.4276 sec/batch
Epoch 1/10  Iteration 111/1780 Training loss: 3.2005 0.4406 sec/batch
Epoch 1/10  Iteration 112/1780 Training loss: 3.1984 0.4086 sec/batch
Epoch 1/10  Iteration 113/1780 Training loss: 3.1963 0.4078 sec/batch
Epoch 1/10  Iteration 114/1780 Training loss: 3.1940 0.4171 sec/batch
Epoch 1/10  Iteration 115/1780 Training loss: 3.1917 0.4269 sec/batch
Epoch 1/10  Iteration 116/1780 Training loss: 3.1898 0.4346 sec/batch
Epoch 1/10  Iteration 117/1780 Training loss: 3.1877 0.4307 sec/batch
Epoch 1/10  Iteration 118/1780 Training loss: 3.1858 0.4334 sec/batch
Epoch 1/10  Iteration 119/1780 Training loss: 3.1839 0.4075 sec/batch
Epoch 1/10  Iteration 120/1780 Training loss: 3.1817 0.4149 sec/batch
Epoch 1/10  Iteration 121/1780 Training loss: 3.1799 0.4114 sec/batch
Epoch 1/10  Iteration 122/1780 Training loss: 3.1777 0.4300 sec/batch
Epoch 1/10  Iteration 123/1780 Training loss: 3.1755 0.4181 sec/batch
Epoch 1/10  Iteration 124/1780 Training loss: 3.1734 0.4063 sec/batch
Epoch 1/10  Iteration 125/1780 Training loss: 3.1710 0.4093 sec/batch
Epoch 1/10  Iteration 126/1780 Training loss: 3.1683 0.4142 sec/batch
Epoch 1/10  Iteration 127/1780 Training loss: 3.1659 0.4083 sec/batch
Epoch 1/10  Iteration 128/1780 Training loss: 3.1635 0.4084 sec/batch
Epoch 1/10  Iteration 129/1780 Training loss: 3.1610 0.4076 sec/batch
Epoch 1/10  Iteration 130/1780 Training loss: 3.1584 0.4082 sec/batch
Epoch 1/10  Iteration 131/1780 Training loss: 3.1558 0.4074 sec/batch
Epoch 1/10  Iteration 132/1780 Training loss: 3.1531 0.4138 sec/batch
Epoch 1/10  Iteration 133/1780 Training loss: 3.1504 0.4078 sec/batch
Epoch 1/10  Iteration 134/1780 Training loss: 3.1477 0.4076 sec/batch
Epoch 1/10  Iteration 135/1780 Training loss: 3.1447 0.4083 sec/batch
Epoch 1/10  Iteration 136/1780 Training loss: 3.1420 0.4069 sec/batch
Epoch 1/10  Iteration 137/1780 Training loss: 3.1396 0.4082 sec/batch
Epoch 1/10  Iteration 138/1780 Training loss: 3.1370 0.4111 sec/batch
Epoch 1/10  Iteration 139/1780 Training loss: 3.1347 0.4082 sec/batch
Epoch 1/10  Iteration 140/1780 Training loss: 3.1320 0.4202 sec/batch
Epoch 1/10  Iteration 141/1780 Training loss: 3.1294 0.4176 sec/batch
Epoch 1/10  Iteration 142/1780 Training loss: 3.1267 0.4137 sec/batch
Epoch 1/10  Iteration 143/1780 Training loss: 3.1240 0.4157 sec/batch
Epoch 1/10  Iteration 144/1780 Training loss: 3.1212 0.4124 sec/batch
Epoch 1/10  Iteration 145/1780 Training loss: 3.1185 0.4137 sec/batch
Epoch 1/10  Iteration 146/1780 Training loss: 3.1158 0.4077 sec/batch
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-13-49ddafcbde86> in <module>()
     36             summary, batch_loss, new_state, _ = sess.run([model.merged, model.cost, 
     37                                                           model.final_state, model.optimizer], 
---> 38                                                           feed_dict=feed)
     39             loss += batch_loss
     40             end = time.time()

/home/luo/anaconda2/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/luo/anaconda2/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

/home/luo/anaconda2/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/luo/anaconda2/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

/home/luo/anaconda2/envs/dlnd/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002         return tf_session.TF_Run(session, options,
   1003                                  feed_dict, fetch_list, target_list,
-> 1004                                  status, run_metadata)
   1005 
   1006     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [14]:
tf.train.get_checkpoint_state('checkpoints/anna')

Sampling

Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, then the network will predict the next character. We can then feed that new character back in to predict the one after it, and keep doing this to generate all new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.

The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.


In [17]:
def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds)
    # Zero out everything except the top_n most likely characters
    p[np.argsort(p)[:-top_n]] = 0
    # Renormalize so the remaining probabilities sum to 1
    p = p / np.sum(p)
    # Sample a character index from the reduced distribution
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c
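
As a small illustration with made-up probabilities: for a 5-character vocabulary and top_n=2, only the two most likely characters can be sampled, with their probabilities renormalized.

# Toy example: 5-character vocabulary with made-up prediction probabilities
toy_preds = np.array([[0.1, 0.5, 0.05, 0.3, 0.05]])
# Only indices 1 and 3 survive, renormalized to 0.625 and 0.375
print(pick_top_n(toy_preds, vocab_size=5, top_n=2))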

In [41]:
def sample(checkpoint, n_samples, lstm_size, vocab_size, prime="The "):
    # Start the generated text with the priming string
    samples = [c for c in prime]
    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        new_state = sess.run(model.initial_state)
        for c in prime:
            x = np.zeros((1, 1))
            x[0,0] = vocab_to_int[c]
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

        c = pick_top_n(preds, len(vocab))
        samples.append(int_to_vocab[c])

        for i in range(n_samples):
            x[0,0] = c
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

            c = pick_top_n(preds, len(vocab))
            samples.append(int_to_vocab[c])
        
    return ''.join(samples)

In [44]:
checkpoint = "checkpoints/anna/i3560_l512_1.122.ckpt"
samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime="Far")
print(samp)


Farlathit that if had so
like it that it were. He could not trouble to his wife, and there was
anything in them of the side of his weaky in the creature at his forteren
to him.

"What is it? I can't bread to those," said Stepan Arkadyevitch. "It's not
my children, and there is an almost this arm, true it mays already,
and tell you what I have say to you, and was not looking at the peasant,
why is, I don't know him out, and she doesn't speak to me immediately, as
you would say the countess and the more frest an angelembre, and time and
things's silent, but I was not in my stand that is in my head. But if he
say, and was so feeling with his soul. A child--in his soul of his
soul of his soul. He should not see that any of that sense of. Here he
had not been so composed and to speak for as in a whole picture, but
all the setting and her excellent and society, who had been delighted
and see to anywing had been being troed to thousand words on them,
we liked him.

That set in her money at the table, he came into the party. The capable
of his she could not be as an old composure.

"That's all something there will be down becime by throe is
such a silent, as in a countess, I should state it out and divorct.
The discussion is not for me. I was that something was simply they are
all three manshess of a sensitions of mind it all."

"No," he thought, shouted and lifting his soul. "While it might see your
honser and she, I could burst. And I had been a midelity. And I had a
marnief are through the countess," he said, looking at him, a chosing
which they had been carried out and still solied, and there was a sen that
was to be completely, and that this matter of all the seconds of it, and
a concipation were to her husband, who came up and conscaously, that he
was not the station. All his fourse she was always at the country,,
to speak oft, and though they were to hear the delightful throom and
whether they came towards the morning, and his living and a coller and
hold--the children. 

In [43]:
checkpoint = "checkpoints/anna/i200_l512_2.432.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Farnt him oste wha sorind thans tout thint asd an sesand an hires on thime sind thit aled, ban thand and out hore as the ter hos ton ho te that, was tis tart al the hand sostint him sore an tit an son thes, win he se ther san ther hher tas tarereng,.

Anl at an ades in ond hesiln, ad hhe torers teans, wast tar arering tho this sos alten sorer has hhas an siton ther him he had sin he ard ate te anling the sosin her ans and
arins asd and ther ale te tot an tand tanginge wath and ho ald, so sot th asend sat hare sother horesinnd, he hesense wing ante her so tith tir sherinn, anded and to the toul anderin he sorit he torsith she se atere an ting ot hand and thit hhe so the te wile har
ens ont in the sersise, and we he seres tar aterer, to ato tat or has he he wan ton here won and sen heren he sosering, to to theer oo adent har herere the wosh oute, was serild ward tous hed astend..

I's sint on alt in har tor tit her asd hade shithans ored he talereng an soredendere tim tot hees. Tise sor and 

In [46]:
checkpoint = "checkpoints/anna/i600_l512_1.750.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Fard as astice her said he celatice of to seress in the raice, and to be the some and sere allats to that said to that the sark and a cast a the wither ald the pacinesse of her had astition, he said to the sount as she west at hissele. Af the cond it he was a fact onthis astisarianing.


"Or a ton to to be that's a more at aspestale as the sont of anstiring as
thours and trey.

The same wo dangring the
raterst, who sore and somethy had ast out an of his book. "We had's beane were that, and a morted a thay he had to tere. Then to
her homent andertersed his his ancouted to the pirsted, the soution for of the pirsice inthirgest and stenciol, with the hard and and
a colrice of to be oneres,
the song to this anderssad.
The could ounterss the said to serom of
soment a carsed of sheres of she
torded
har and want in their of hould, but
her told in that in he tad a the same to her. Serghing an her has and with the seed, and the camt ont his about of the
sail, the her then all houg ant or to hus to 

In [47]:
checkpoint = "checkpoints/anna/i1000_l512_1.484.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Farrat, his felt has at it.

"When the pose ther hor exceed
to his sheant was," weat a sime of his sounsed. The coment and the facily that which had began terede a marilicaly whice whether the pose of his hand, at she was alligated herself the same on she had to
taiking to his forthing and streath how to hand
began in a lang at some at it, this he cholded not set all her. "Wo love that is setthing. Him anstering as seen that."

"Yes in the man that say the mare a crances is it?" said Sergazy Ivancatching. "You doon think were somether is ifficult of a mone of
though the most at the countes that the
mean on the come to say the most, to
his feesing of
a man she, whilo he
sained and well, that he would still at to said. He wind at his for the sore in the most
of hoss and almoved to see him. They have betine the sumper into at he his stire, and what he was that at the so steate of the
sound, and shin should have a geest of shall feet on the conderation to she had been at that imporsing the dre