Anna KaRNNa

In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.

This network is based on Andrej Karpathy's post on RNNs and his implementation in Torch, with additional material from r2rt and from Sherjil Ozair's implementation on GitHub. Below is the general architecture of the character-wise RNN.


In [1]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

First we'll load the text file and convert it into integers for our network to use.


In [2]:
with open('anna.txt', 'r') as f:
    text = f.read()
vocab = set(text)
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))
chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)
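
As a quick sanity check (illustrative, not part of the original notebook), we can map the integers back to characters with int_to_vocab and confirm the encoding round-trips:

# Decoding the first 100 integers should reproduce the first 100 characters.
decoded = ''.join(int_to_vocab[int(i)] for i in chars[:100])
print(decoded == text[:100])  # expect True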

In [3]:
text[:100]


Out[3]:
'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars[:100]


Out[4]:
array([ 0, 47, 28,  8, 35, 23, 38, 24, 42, 26, 26, 26,  5, 28,  8,  8, 49,
       24, 74, 28, 32, 11, 77, 11, 23, 16, 24, 28, 38, 23, 24, 28, 77, 77,
       24, 28, 77, 11, 63, 23, 79, 24, 23, 72, 23, 38, 49, 24, 65, 33, 47,
       28,  8,  8, 49, 24, 74, 28, 32, 11, 77, 49, 24, 11, 16, 24, 65, 33,
       47, 28,  8,  8, 49, 24, 11, 33, 24, 11, 35, 16, 24, 78,  1, 33, 26,
        1, 28, 49,  9, 26, 26, 31, 72, 23, 38, 49, 35, 47, 11, 33])

Now I need to split the data into batches, and into training and validation sets. I should be making a test set here, but I'm not going to worry about that; my test will be whether the network can generate new text.

Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last bit of data so that I'll only have completely full batches.

The idea here is to make a 2D matrix where the number of rows is equal to batch_size. Each row will be one long concatenated string from the character data. We'll split this data into a training set and validation set using the split_frac keyword. With the default of 0.9, this keeps 90% of the sequence batches in the training set and the other 10% in the validation set.
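
To make the arithmetic concrete: with batch_size = 10 and num_steps = 200 (the values used below), slice_size is 2,000 characters, so n_batches = len(chars) // 2000 full windows fit, and the training portion keeps int(n_batches * 0.9) of them. Here that comes to 892 windows, i.e. 892 * 200 = 178,400 columns per row, matching the shape shown in Out[7] below.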


In [5]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """ 
    Split character data into training and validation sets, inputs and targets for each set.
    
    Arguments
    ---------
    chars: character array
    batch_size: Number of sequences per batch (rows of the split arrays)
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set
    
    
    Returns train_x, train_y, val_x, val_y
    """
    
    
    slice_size = batch_size * num_steps
    n_batches = int(len(chars) / slice_size)
    
    # Drop the last few characters to make only full batches
    x = chars[: n_batches*slice_size]
    y = chars[1: n_batches*slice_size + 1]
    
    # Split the data into batch_size slices, then stack them into a 2D matrix 
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    
    # Split into training and validation sets, keeping the first split_frac of batches for training
    split_idx = int(n_batches*split_frac)
    train_x, train_y = x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]
    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]
    
    return train_x, train_y, val_x, val_y

In [6]:
train_x, train_y, val_x, val_y = split_data(chars, 10, 200)

In [7]:
train_x.shape


Out[7]:
(10, 178400)

In [8]:
train_x[:,:10]


Out[8]:
array([[ 0, 47, 28,  8, 35, 23, 38, 24, 42, 26],
       [19, 33, 56, 24, 47, 23, 24, 32, 78, 72],
       [24,  3, 28, 35,  3, 47, 11, 33,  2, 24],
       [78, 35, 47, 23, 38, 24,  1, 78, 65, 77],
       [24, 35, 47, 23, 24, 77, 28, 33, 56, 25],
       [24, 27, 47, 38, 78, 65,  2, 47, 24, 77],
       [35, 24, 35, 78, 26, 56, 78,  9, 26, 26],
       [78, 24, 47, 23, 38, 16, 23, 77, 74, 62],
       [47, 28, 35, 24, 11, 16, 24, 35, 47, 23],
       [23, 38, 16, 23, 77, 74, 24, 28, 33, 56]])

I'll write another function to grab batches out of the arrays made by split_data. Here each batch will be a sliding window on these arrays with size batch_size x num_steps. For example, if we want our network to train on sequences of 100 characters, num_steps = 100. For the next batch, we'll shift this window over by num_steps characters. In this way we can feed batches to the network and the cell states will carry over from one batch to the next.


In [9]:
def get_batch(arrs, num_steps):
    batch_size, slice_size = arrs[0].shape
    
    n_batches = int(slice_size/num_steps)
    for b in range(n_batches):
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]
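
A quick check (illustrative, reusing the train_x and train_y arrays created above with num_steps = 200): the first batch should have shape batch_size x num_steps, and within it the targets are the inputs shifted one character to the left.

# First batch from the generator: check shapes and the one-character shift.
xb, yb = next(get_batch([train_x, train_y], 200))
print(xb.shape, yb.shape)               # expect (10, 200) (10, 200)
print((xb[:, 1:] == yb[:, :-1]).all())  # expect True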

In [10]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
        
    if sampling:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    
    inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
    x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')


    targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
    y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')
    y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # Build the RNN layers
    
    lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    cell = tf.contrib.rnn.MultiRNNCell([drop] * num_layers)

    initial_state = cell.zero_state(batch_size, tf.float32)

    # Run the data through the RNN layers
    rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(x_one_hot, num_steps, 1)]
    outputs, state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=initial_state)
    
    final_state = state
    
    # Reshape output so it's a bunch of rows, one row for each cell output
    
    seq_output = tf.concat(outputs, axis=1, name='seq_output')
    output = tf.reshape(seq_output, [-1, lstm_size], name='graph_output')
    
    # Now connect the RNN outputs to a softmax layer and calculate the cost
    softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),
                           name='softmax_w')
    softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')
    logits = tf.matmul(output, softmax_w) + softmax_b

    preds = tf.nn.softmax(logits, name='predictions')
    
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')
    cost = tf.reduce_mean(loss, name='cost')

    # Optimizer for training, using gradient clipping to control exploding gradients
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
    train_op = tf.train.AdamOptimizer(learning_rate)
    optimizer = train_op.apply_gradients(zip(grads, tvars))

    # Export the nodes 
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph

Hyperparameters

Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are lstm_size and num_layers. These set the number of hidden units in the LSTM layers and the number of LSTM layers, respectively. Making these bigger will generally improve the network's performance, but you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting; decrease the size of the network or decrease the dropout keep probability.


In [11]:
batch_size = 100
num_steps = 100
lstm_size = 512
num_layers = 2
learning_rate = 0.001

Write out the graph for TensorBoard


In [12]:
model = build_rnn(len(vocab),
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    file_writer = tf.summary.FileWriter('./logs/1', sess.graph)
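
To inspect the exported graph, you can point TensorBoard at the log directory from a terminal (assuming TensorBoard is installed), e.g. tensorboard --logdir ./logs/1, and open the address it prints.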

Training

Time for training, which is pretty straightforward. Here I pass in some data and get an LSTM state back. Then I pass that state back into the network so the next batch can continue the state from the previous batch. And every so often (set by save_every_n) I calculate the validation loss and save a checkpoint.


In [13]:
!mkdir -p checkpoints/anna


The syntax of the command is incorrect.
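
The error above comes from running the cell on Windows, where cmd doesn't recognize mkdir -p. A portable alternative (not in the original notebook) is to create the directory from Python instead:

import os
os.makedirs('checkpoints/anna', exist_ok=True)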

In [14]:
epochs = 1
save_every_n = 200
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/anna20.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: 0.5,
                    model.initial_state: new_state}
            batch_loss, new_state, _ = sess.run([model.cost, model.final_state, model.optimizer], 
                                                 feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
        
            
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance, notice dropout has been set to 1
                val_loss = []
                new_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: new_state}
                    batch_loss, new_state = sess.run([model.cost, model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                saver.save(sess, "checkpoints/anna/i{}_l{}_{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))


Epoch 1/1  Iteration 1/178 Training loss: 4.4204 5.2430 sec/batch
Epoch 1/1  Iteration 2/178 Training loss: 4.3785 3.5467 sec/batch
Epoch 1/1  Iteration 3/178 Training loss: 4.2356 3.4171 sec/batch
Epoch 1/1  Iteration 4/178 Training loss: 4.6379 3.4191 sec/batch
Epoch 1/1  Iteration 5/178 Training loss: 4.5420 3.3663 sec/batch
Epoch 1/1  Iteration 6/178 Training loss: 4.4081 3.3982 sec/batch
Epoch 1/1  Iteration 7/178 Training loss: 4.2930 3.3872 sec/batch
Epoch 1/1  Iteration 8/178 Training loss: 4.1977 3.5041 sec/batch
Epoch 1/1  Iteration 9/178 Training loss: 4.1130 3.3673 sec/batch
Epoch 1/1  Iteration 10/178 Training loss: 4.0409 3.3733 sec/batch
Epoch 1/1  Iteration 11/178 Training loss: 3.9778 3.3354 sec/batch
Epoch 1/1  Iteration 12/178 Training loss: 3.9246 3.3633 sec/batch
Epoch 1/1  Iteration 13/178 Training loss: 3.8769 3.6076 sec/batch
Epoch 1/1  Iteration 14/178 Training loss: 3.8350 3.4948 sec/batch
Epoch 1/1  Iteration 15/178 Training loss: 3.7970 3.3682 sec/batch
Epoch 1/1  Iteration 16/178 Training loss: 3.7637 3.3663 sec/batch
Epoch 1/1  Iteration 17/178 Training loss: 3.7335 3.4291 sec/batch
Epoch 1/1  Iteration 18/178 Training loss: 3.7090 3.3802 sec/batch
Epoch 1/1  Iteration 19/178 Training loss: 3.6853 3.4819 sec/batch
Epoch 1/1  Iteration 20/178 Training loss: 3.6612 3.3154 sec/batch
Epoch 1/1  Iteration 21/178 Training loss: 3.6405 3.4600 sec/batch
Epoch 1/1  Iteration 22/178 Training loss: 3.6215 3.3792 sec/batch
Epoch 1/1  Iteration 23/178 Training loss: 3.6037 3.3154 sec/batch
Epoch 1/1  Iteration 24/178 Training loss: 3.5876 3.4081 sec/batch
Epoch 1/1  Iteration 25/178 Training loss: 3.5720 3.3139 sec/batch
Epoch 1/1  Iteration 26/178 Training loss: 3.5580 3.2636 sec/batch
Epoch 1/1  Iteration 27/178 Training loss: 3.5452 3.2745 sec/batch
Epoch 1/1  Iteration 28/178 Training loss: 3.5324 3.3005 sec/batch
Epoch 1/1  Iteration 29/178 Training loss: 3.5209 3.3124 sec/batch
Epoch 1/1  Iteration 30/178 Training loss: 3.5105 3.2775 sec/batch
Epoch 1/1  Iteration 31/178 Training loss: 3.5009 3.3443 sec/batch
Epoch 1/1  Iteration 32/178 Training loss: 3.4911 3.2935 sec/batch
Epoch 1/1  Iteration 33/178 Training loss: 3.4815 3.2835 sec/batch
Epoch 1/1  Iteration 34/178 Training loss: 3.4731 3.3733 sec/batch
Epoch 1/1  Iteration 35/178 Training loss: 3.4646 3.3323 sec/batch
Epoch 1/1  Iteration 36/178 Training loss: 3.4570 3.2935 sec/batch
Epoch 1/1  Iteration 37/178 Training loss: 3.4490 3.2665 sec/batch
Epoch 1/1  Iteration 38/178 Training loss: 3.4415 3.3054 sec/batch
Epoch 1/1  Iteration 39/178 Training loss: 3.4343 3.3383 sec/batch
Epoch 1/1  Iteration 40/178 Training loss: 3.4274 3.3323 sec/batch
Epoch 1/1  Iteration 41/178 Training loss: 3.4208 3.3154 sec/batch
Epoch 1/1  Iteration 42/178 Training loss: 3.4146 3.3034 sec/batch
Epoch 1/1  Iteration 43/178 Training loss: 3.4085 3.2985 sec/batch
Epoch 1/1  Iteration 44/178 Training loss: 3.4028 3.3044 sec/batch
Epoch 1/1  Iteration 45/178 Training loss: 3.3969 3.2885 sec/batch
Epoch 1/1  Iteration 46/178 Training loss: 3.3918 3.3274 sec/batch
Epoch 1/1  Iteration 47/178 Training loss: 3.3869 3.3124 sec/batch
Epoch 1/1  Iteration 48/178 Training loss: 3.3823 3.2925 sec/batch
Epoch 1/1  Iteration 49/178 Training loss: 3.3777 3.3612 sec/batch
Epoch 1/1  Iteration 50/178 Training loss: 3.3734 3.3044 sec/batch
Epoch 1/1  Iteration 51/178 Training loss: 3.3691 3.3443 sec/batch
Epoch 1/1  Iteration 52/178 Training loss: 3.3648 3.3104 sec/batch
Epoch 1/1  Iteration 53/178 Training loss: 3.3610 3.2815 sec/batch
Epoch 1/1  Iteration 54/178 Training loss: 3.3569 3.2695 sec/batch
Epoch 1/1  Iteration 55/178 Training loss: 3.3532 3.2765 sec/batch
Epoch 1/1  Iteration 56/178 Training loss: 3.3492 3.2985 sec/batch
Epoch 1/1  Iteration 57/178 Training loss: 3.3456 3.3300 sec/batch
Epoch 1/1  Iteration 58/178 Training loss: 3.3421 3.3030 sec/batch
Epoch 1/1  Iteration 59/178 Training loss: 3.3385 3.2930 sec/batch
Epoch 1/1  Iteration 60/178 Training loss: 3.3352 3.3090 sec/batch
Epoch 1/1  Iteration 61/178 Training loss: 3.3319 3.3160 sec/batch
Epoch 1/1  Iteration 62/178 Training loss: 3.3291 3.2790 sec/batch
Epoch 1/1  Iteration 63/178 Training loss: 3.3265 3.3220 sec/batch
Epoch 1/1  Iteration 64/178 Training loss: 3.3232 3.3280 sec/batch
Epoch 1/1  Iteration 65/178 Training loss: 3.3202 3.2830 sec/batch
Epoch 1/1  Iteration 66/178 Training loss: 3.3176 3.2770 sec/batch
Epoch 1/1  Iteration 67/178 Training loss: 3.3149 3.3060 sec/batch
Epoch 1/1  Iteration 68/178 Training loss: 3.3116 3.3610 sec/batch
Epoch 1/1  Iteration 69/178 Training loss: 3.3088 3.2890 sec/batch
Epoch 1/1  Iteration 70/178 Training loss: 3.3063 3.4030 sec/batch
Epoch 1/1  Iteration 71/178 Training loss: 3.3036 3.4120 sec/batch
Epoch 1/1  Iteration 72/178 Training loss: 3.3013 3.2880 sec/batch
Epoch 1/1  Iteration 73/178 Training loss: 3.2988 3.3060 sec/batch
Epoch 1/1  Iteration 74/178 Training loss: 3.2965 3.3450 sec/batch
Epoch 1/1  Iteration 75/178 Training loss: 3.2941 3.3190 sec/batch
Epoch 1/1  Iteration 76/178 Training loss: 3.2920 3.3250 sec/batch
Epoch 1/1  Iteration 77/178 Training loss: 3.2898 3.3000 sec/batch
Epoch 1/1  Iteration 78/178 Training loss: 3.2877 3.3340 sec/batch
Epoch 1/1  Iteration 79/178 Training loss: 3.2854 3.3000 sec/batch
Epoch 1/1  Iteration 80/178 Training loss: 3.2829 3.3070 sec/batch
Epoch 1/1  Iteration 81/178 Training loss: 3.2806 3.3200 sec/batch
Epoch 1/1  Iteration 82/178 Training loss: 3.2785 3.4180 sec/batch
Epoch 1/1  Iteration 83/178 Training loss: 3.2765 3.3350 sec/batch
Epoch 1/1  Iteration 84/178 Training loss: 3.2751 3.4171 sec/batch
Epoch 1/1  Iteration 85/178 Training loss: 3.2737 3.6611 sec/batch
Epoch 1/1  Iteration 86/178 Training loss: 3.2718 3.3970 sec/batch
Epoch 1/1  Iteration 87/178 Training loss: 3.2698 3.4000 sec/batch
Epoch 1/1  Iteration 88/178 Training loss: 3.2678 3.4716 sec/batch
Epoch 1/1  Iteration 89/178 Training loss: 3.2661 3.5716 sec/batch
Epoch 1/1  Iteration 90/178 Training loss: 3.2644 3.4877 sec/batch
Epoch 1/1  Iteration 91/178 Training loss: 3.2627 3.6463 sec/batch
Epoch 1/1  Iteration 92/178 Training loss: 3.2610 3.3060 sec/batch
Epoch 1/1  Iteration 93/178 Training loss: 3.2593 3.4187 sec/batch
Epoch 1/1  Iteration 94/178 Training loss: 3.2576 3.5551 sec/batch
Epoch 1/1  Iteration 95/178 Training loss: 3.2558 3.3050 sec/batch
Epoch 1/1  Iteration 96/178 Training loss: 3.2540 3.3290 sec/batch
Epoch 1/1  Iteration 97/178 Training loss: 3.2524 3.3945 sec/batch
Epoch 1/1  Iteration 98/178 Training loss: 3.2507 3.6517 sec/batch
Epoch 1/1  Iteration 99/178 Training loss: 3.2491 3.2800 sec/batch
Epoch 1/1  Iteration 100/178 Training loss: 3.2475 3.3500 sec/batch
Epoch 1/1  Iteration 101/178 Training loss: 3.2459 3.2910 sec/batch
Epoch 1/1  Iteration 102/178 Training loss: 3.2444 3.2940 sec/batch
Epoch 1/1  Iteration 103/178 Training loss: 3.2428 3.3410 sec/batch
Epoch 1/1  Iteration 104/178 Training loss: 3.2412 3.3400 sec/batch
Epoch 1/1  Iteration 105/178 Training loss: 3.2396 3.3180 sec/batch
Epoch 1/1  Iteration 106/178 Training loss: 3.2380 3.3520 sec/batch
Epoch 1/1  Iteration 107/178 Training loss: 3.2362 3.4717 sec/batch
Epoch 1/1  Iteration 108/178 Training loss: 3.2345 3.6021 sec/batch
Epoch 1/1  Iteration 109/178 Training loss: 3.2330 3.2920 sec/batch
Epoch 1/1  Iteration 110/178 Training loss: 3.2311 3.3180 sec/batch
Epoch 1/1  Iteration 111/178 Training loss: 3.2295 3.3370 sec/batch
Epoch 1/1  Iteration 112/178 Training loss: 3.2278 3.2870 sec/batch
Epoch 1/1  Iteration 113/178 Training loss: 3.2260 3.6817 sec/batch
Epoch 1/1  Iteration 114/178 Training loss: 3.2243 3.3380 sec/batch
Epoch 1/1  Iteration 115/178 Training loss: 3.2225 3.3300 sec/batch
Epoch 1/1  Iteration 116/178 Training loss: 3.2208 3.7183 sec/batch
Epoch 1/1  Iteration 117/178 Training loss: 3.2191 3.2839 sec/batch
Epoch 1/1  Iteration 118/178 Training loss: 3.2175 3.2980 sec/batch
Epoch 1/1  Iteration 119/178 Training loss: 3.2159 3.3120 sec/batch
Epoch 1/1  Iteration 120/178 Training loss: 3.2142 3.3050 sec/batch
Epoch 1/1  Iteration 121/178 Training loss: 3.2136 3.3310 sec/batch
Epoch 1/1  Iteration 122/178 Training loss: 3.2133 3.3350 sec/batch
Epoch 1/1  Iteration 123/178 Training loss: 3.2126 3.6867 sec/batch
Epoch 1/1  Iteration 124/178 Training loss: 3.2111 3.4511 sec/batch
Epoch 1/1  Iteration 125/178 Training loss: 3.2093 3.6096 sec/batch
Epoch 1/1  Iteration 126/178 Training loss: 3.2073 3.2950 sec/batch
Epoch 1/1  Iteration 127/178 Training loss: 3.2056 3.3930 sec/batch
Epoch 1/1  Iteration 128/178 Training loss: 3.2039 3.7737 sec/batch
Epoch 1/1  Iteration 129/178 Training loss: 3.2020 3.3776 sec/batch
Epoch 1/1  Iteration 130/178 Training loss: 3.2004 3.6162 sec/batch
Epoch 1/1  Iteration 131/178 Training loss: 3.1992 3.6092 sec/batch
Epoch 1/1  Iteration 132/178 Training loss: 3.1977 3.3620 sec/batch
Epoch 1/1  Iteration 133/178 Training loss: 3.1962 3.3110 sec/batch
Epoch 1/1  Iteration 134/178 Training loss: 3.1944 3.3446 sec/batch
Epoch 1/1  Iteration 135/178 Training loss: 3.1924 3.6677 sec/batch
Epoch 1/1  Iteration 136/178 Training loss: 3.1905 3.2960 sec/batch
Epoch 1/1  Iteration 137/178 Training loss: 3.1888 3.3050 sec/batch
Epoch 1/1  Iteration 138/178 Training loss: 3.1870 3.3230 sec/batch
Epoch 1/1  Iteration 139/178 Training loss: 3.1852 3.7508 sec/batch
Epoch 1/1  Iteration 140/178 Training loss: 3.1834 3.3050 sec/batch
Epoch 1/1  Iteration 141/178 Training loss: 3.1815 3.4044 sec/batch
Epoch 1/1  Iteration 142/178 Training loss: 3.1795 3.7461 sec/batch
Epoch 1/1  Iteration 143/178 Training loss: 3.1775 3.3861 sec/batch
Epoch 1/1  Iteration 144/178 Training loss: 3.1754 3.7082 sec/batch
Epoch 1/1  Iteration 145/178 Training loss: 3.1733 3.3800 sec/batch
Epoch 1/1  Iteration 146/178 Training loss: 3.1713 3.5120 sec/batch
Epoch 1/1  Iteration 147/178 Training loss: 3.1692 3.4720 sec/batch
Epoch 1/1  Iteration 148/178 Training loss: 3.1671 3.3400 sec/batch
Epoch 1/1  Iteration 149/178 Training loss: 3.1649 3.4130 sec/batch
Epoch 1/1  Iteration 150/178 Training loss: 3.1627 3.3680 sec/batch
Epoch 1/1  Iteration 151/178 Training loss: 3.1607 3.4460 sec/batch
Epoch 1/1  Iteration 152/178 Training loss: 3.1587 3.7117 sec/batch
Epoch 1/1  Iteration 153/178 Training loss: 3.1565 3.6850 sec/batch
Epoch 1/1  Iteration 154/178 Training loss: 3.1542 3.3470 sec/batch
Epoch 1/1  Iteration 155/178 Training loss: 3.1518 3.4240 sec/batch
Epoch 1/1  Iteration 156/178 Training loss: 3.1493 3.3510 sec/batch
Epoch 1/1  Iteration 157/178 Training loss: 3.1468 3.3950 sec/batch
Epoch 1/1  Iteration 158/178 Training loss: 3.1443 3.4130 sec/batch
Epoch 1/1  Iteration 159/178 Training loss: 3.1418 3.3690 sec/batch
Epoch 1/1  Iteration 160/178 Training loss: 3.1397 3.4260 sec/batch
Epoch 1/1  Iteration 161/178 Training loss: 3.1376 3.5350 sec/batch
Epoch 1/1  Iteration 162/178 Training loss: 3.1351 3.7147 sec/batch
Epoch 1/1  Iteration 163/178 Training loss: 3.1327 4.0640 sec/batch
Epoch 1/1  Iteration 164/178 Training loss: 3.1304 3.5666 sec/batch
Epoch 1/1  Iteration 165/178 Training loss: 3.1279 3.5182 sec/batch
Epoch 1/1  Iteration 166/178 Training loss: 3.1255 3.7282 sec/batch
Epoch 1/1  Iteration 167/178 Training loss: 3.1231 3.4250 sec/batch
Epoch 1/1  Iteration 168/178 Training loss: 3.1206 3.9319 sec/batch
Epoch 1/1  Iteration 169/178 Training loss: 3.1181 3.6481 sec/batch
Epoch 1/1  Iteration 170/178 Training loss: 3.1155 3.9292 sec/batch
Epoch 1/1  Iteration 171/178 Training loss: 3.1130 3.4711 sec/batch
Epoch 1/1  Iteration 172/178 Training loss: 3.1105 3.8082 sec/batch
Epoch 1/1  Iteration 173/178 Training loss: 3.1082 3.3930 sec/batch
Epoch 1/1  Iteration 174/178 Training loss: 3.1058 3.7346 sec/batch
Epoch 1/1  Iteration 175/178 Training loss: 3.1033 3.4481 sec/batch
Epoch 1/1  Iteration 176/178 Training loss: 3.1006 3.7657 sec/batch
Epoch 1/1  Iteration 177/178 Training loss: 3.0979 3.7508 sec/batch
Epoch 1/1  Iteration 178/178 Training loss: 3.0951 4.0406 sec/batch
Validation loss: 2.52963 Saving checkpoint!

In [15]:
tf.train.get_checkpoint_state('checkpoints/anna')


Out[15]:
model_checkpoint_path: "checkpoints/anna\\i178_l512_2.530.ckpt"
all_model_checkpoint_paths: "checkpoints/anna\\i178_l512_2.530.ckpt"

Sampling

Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, and the network predicts the next character. We can then use that new character to predict the one after it, and keep going to generate arbitrarily long text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.

The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.


In [16]:
def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds)
    p[np.argsort(p)[:-top_n]] = 0
    p = p / np.sum(p)
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c
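
A toy run (illustrative only) shows the effect: with top_n = 2, only the two most likely indices can ever be sampled.

# With top_n=2, indices 2 (p=0.60) and 3 (p=0.25) are the only possible picks.
# .copy() is used because pick_top_n zeroes low-probability entries in place.
toy_preds = np.array([[0.05, 0.10, 0.60, 0.25]])
print([pick_top_n(toy_preds.copy(), vocab_size=4, top_n=2) for _ in range(5)])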

In [17]:
def sample(checkpoint, n_samples, lstm_size, vocab_size, prime="The "):
    samples = [c for c in prime]
    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        new_state = sess.run(model.initial_state)
        for c in prime:
            x = np.zeros((1, 1))
            x[0,0] = vocab_to_int[c]
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

        c = pick_top_n(preds, len(vocab))
        samples.append(int_to_vocab[c])

        for i in range(n_samples):
            x[0,0] = c
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

            c = pick_top_n(preds, len(vocab))
            samples.append(int_to_vocab[c])
        
    return ''.join(samples)

In [19]:
checkpoint = "checkpoints/anna/i178_l512_2.530.ckpt"
samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime="Far")
print(samp)


Farndd" an oth ho sade whit ion, wh oatentit he an oed tho sorte tan the tot tor otharssens totin she wate tind has his an ade tit aton hil hhar tees aoting hheses hes weris hine wint or al thas and an oo her ate ho ter thit totersse ate oit thos this that
aned oa her anes on hostitin, wo tith and and won itin tet an tit hire sand ton oe han antin in hher ateth his weat ates thot ho hee ade an te hherinn thed sone th tires and hored wos oter thas an hhe te oo tas ithes, an aneter in ansting han aris ian he ile te hatan tin this the sito hes woat hil oote serto hos tote her ansinn otis hetetess hos torin ad whe han at ootite asd te thine te shes tale tot hes he sond oo tent ing terersint hhe satil al an thiles tal an an al aleded ton aneses oat th al ante ton and oting
atit he tan and on at oot and on tes arind at atorid sot on tas hese ad ten otorerte he aring arte tirog hot war tal thes and ter hite so hand tee atin herthe the hor and hasd inn ite atha th oo this ate sesand wad tant te al thaser ois in tareth the ha tin the thee sher on arerso it oe sentares oal ho te tires
in ion hon asd whed ther hhes
ine thed sose ant oete ter thesed on ate hased to ton tot hare sor hosid, tiset oot hhe ar oin than, ar hed ond who har aod ane tor oot al hers aon he worte he tor tor aro ase hine an aleting ate hed aserte ant ating tat oter on tas an ot he sas oid tetit oe sote ote shar teer oad wo the shad the se ore othi antetering aod thin woned wan in tas aled ant oo hire ant hhe than and tood thit ate anethe as tote ton san tho tes arit ang hot he atilg tote an thhar anthe arditit herins woa hot hoe wit oon he as tot an oit her ates on ath an toore as he sate his ot or at he to tas teed
sin ang an ithan ard alt oo as ites it tot on shesesenn hat ond hhe woad the athe tartand the thar ao he an an anin he terisinn ad tho sarte an thas as toed ton ote artand and wh ithe ares oo the ar oont at in at tar heed sot ion sont he wore he tant ao hin hete one and an hir aotinn thet ant ar i

In [43]:
checkpoint = "checkpoints/anna/i200_l512_2.432.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Farnt him oste wha sorind thans tout thint asd an sesand an hires on thime sind thit aled, ban thand and out hore as the ter hos ton ho te that, was tis tart al the hand sostint him sore an tit an son thes, win he se ther san ther hher tas tarereng,.

Anl at an ades in ond hesiln, ad hhe torers teans, wast tar arering tho this sos alten sorer has hhas an siton ther him he had sin he ard ate te anling the sosin her ans and
arins asd and ther ale te tot an tand tanginge wath and ho ald, so sot th asend sat hare sother horesinnd, he hesense wing ante her so tith tir sherinn, anded and to the toul anderin he sorit he torsith she se atere an ting ot hand and thit hhe so the te wile har
ens ont in the sersise, and we he seres tar aterer, to ato tat or has he he wan ton here won and sen heren he sosering, to to theer oo adent har herere the wosh oute, was serild ward tous hed astend..

I's sint on alt in har tor tit her asd hade shithans ored he talereng an soredendere tim tot hees. Tise sor and 

In [46]:
checkpoint = "checkpoints/anna/i600_l512_1.750.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Fard as astice her said he celatice of to seress in the raice, and to be the some and sere allats to that said to that the sark and a cast a the wither ald the pacinesse of her had astition, he said to the sount as she west at hissele. Af the cond it he was a fact onthis astisarianing.


"Or a ton to to be that's a more at aspestale as the sont of anstiring as
thours and trey.

The same wo dangring the
raterst, who sore and somethy had ast out an of his book. "We had's beane were that, and a morted a thay he had to tere. Then to
her homent andertersed his his ancouted to the pirsted, the soution for of the pirsice inthirgest and stenciol, with the hard and and
a colrice of to be oneres,
the song to this anderssad.
The could ounterss the said to serom of
soment a carsed of sheres of she
torded
har and want in their of hould, but
her told in that in he tad a the same to her. Serghing an her has and with the seed, and the camt ont his about of the
sail, the her then all houg ant or to hus to 

In [47]:
checkpoint = "checkpoints/anna/i1000_l512_1.484.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)


Farrat, his felt has at it.

"When the pose ther hor exceed
to his sheant was," weat a sime of his sounsed. The coment and the facily that which had began terede a marilicaly whice whether the pose of his hand, at she was alligated herself the same on she had to
taiking to his forthing and streath how to hand
began in a lang at some at it, this he cholded not set all her. "Wo love that is setthing. Him anstering as seen that."

"Yes in the man that say the mare a crances is it?" said Sergazy Ivancatching. "You doon think were somether is ifficult of a mone of
though the most at the countes that the
mean on the come to say the most, to
his feesing of
a man she, whilo he
sained and well, that he would still at to said. He wind at his for the sore in the most
of hoss and almoved to see him. They have betine the sumper into at he his stire, and what he was that at the so steate of the
sound, and shin should have a geest of shall feet on the conderation to she had been at that imporsing the dre