RECURRENT NETWORKS and LSTM IN DEEP LEARNING

Applying Recurrent Neural Networks/LSTM for Language Modelling

Hello and welcome to this part. In this notebook, we will go over what Language Modelling is and create a Recurrent Neural Network model, based on the Long Short-Term Memory unit, that is trained on and benchmarked against the Penn Treebank. By the end of this notebook, you should understand how TensorFlow builds and executes an RNN model for Language Modelling.


The Objective

By now, you should have an understanding of how Recurrent Networks work -- a specialized model to process sequential data by keeping track of the "state" or context. In this notebook, we go over a TensorFlow code snippet for creating a model focused on Language Modelling -- a very relevant task that is the cornerstone of many different linguistic problems such as Speech Recognition, Machine Translation and Image Captioning. For this, we will be using the Penn Treebank, which is an often-used dataset for benchmarking Language Modelling models.

What exactly is Language Modelling?

Language Modelling, to put it simply, is the task of assigning probabilities to sequences of words. This means that, given a context of one or a few words in the language the model was trained on, the model should know which words or sequences of words are most likely to come next. Language Modelling is one of the most important tasks in Natural Language Processing.
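
In probabilistic terms, a language model factorizes the probability of a whole sequence of words into a product of next-word probabilities -- a standard formulation, shown here for reference:

$$P(w_1, w_2, \ldots, w_T) = \prod_{t=1}^{T} P(w_t \mid w_1, \ldots, w_{t-1})$$

Each factor is exactly the "given the previous words, how likely is this word?" question that our network will learn to answer.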

*Example of a sentence being predicted*

In this example, one can see the predictions for the next word of a sentence, given the context "This is an". As you can see, this boils down to a sequential data analysis task -- you are given a word or a sequence of words (the input data) and, given the context (the state), you need to find out what the next word is (the prediction). This kind of analysis is very important for language-related tasks such as Speech Recognition, Machine Translation, Image Captioning, Text Correction and many other very relevant problems.

*The above example schematized as an RNN in execution*

As the above image shows, Recurrent Network models fit this problem like a glove. Alongside LSTM and its capacity to maintain the model's state for over one thousand time steps, we have all the tools we need to undertake this problem. The goal for this notebook is to create a model that can reach low levels of perplexity on our desired dataset.

For Language Modelling problems, perplexity is the standard way to gauge how good a model is. Perplexity is simply a measure of how well a probabilistic model is able to predict its sample. A higher-level way to explain this is that low perplexity means a high degree of confidence in the predictions the model makes. Therefore, the lower the perplexity, the better.
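
Formally, perplexity is the exponential of the average negative log probability the model assigns to each target word (the same quantity we will later compute as our loss):

$$\text{perplexity} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \ln p_{\mathrm{target}_i}\right)$$

A perplexity of k roughly means the model is as uncertain as if it were choosing uniformly among k words at each step.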

The Penn Treebank dataset

Historically, datasets big enough for Natural Language Processing are hard to come by. This is in part because sentences need to be broken down and tagged with a certain degree of correctness -- otherwise, the models trained on them cannot be expected to be correct either. This means that we need a large amount of data, annotated by, or at least corrected by, humans. This is, of course, not an easy task at all.

The Penn Treebank, or PTB for short, is a dataset maintained by the University of Pennsylvania. It is huge -- there are over four million and eight hundred thousand annotated words in it, all corrected by humans. It is composed of many different sources, from abstracts of Department of Energy papers to texts from the Library of America. Since it is verifiably correct and of such a huge size, the Penn Treebank has been used time and time again as a benchmark dataset for Language Modelling.

The dataset is divided into different kinds of annotations, such as Part-of-Speech, Syntactic and Semantic skeletons. For this example, we will simply use a sample of clean, non-annotated words (with the exception of one tag -- <unk>, which is used for rare words such as uncommon proper nouns) for our model. This means that we just want to predict what the next words would be, not what they mean in context or their classes in a given sentence.

the percentage of lung cancer deaths among the workers at the west <unk> mass. paper factory appears to be the highest for any asbestos workers studied in western industrialized countries he said the plant which is owned by <unk> & <unk> co. was under contract with <unk> to make the cigarette filters the finding probably will support those who argue that the u.s. should regulate the class of asbestos including <unk> more <unk> than the common kind of asbestos <unk> found in most schools and other buildings dr. <unk> said
*Example of text from the dataset we are going to use, `ptb.train`*

Word Embeddings


For better processing, in this example, we will make use of word embeddings, which are a way of representing words (or sentence structures) as n-dimensional vectors of real numbers, where n is a reasonably high number such as 200 or 500. Basically, we will assign each word a randomly-initialized vector and feed those into the network to be processed. After a number of iterations, these vectors are expected to take on values that help the network correctly predict what it needs to -- in our case, the probable next word in the sentence. This has been shown to be very effective in Natural Language Processing tasks, and is common practice.

$$\mathrm{Vec}(\text{"Example"}) = [0.02, 0.00, 0.00, 0.92, 0.30, \ldots]$$
Word embeddings tend to group similarly used words close together in the vector space. For example, if we use t-SNE (a dimensionality reduction algorithm for visualization) to flatten the dimensions of our vectors into a 2-dimensional space and use the words these vectors represent as their labels, we might see something like this:

*T-SNE Mockup with clusters marked for easier visualization*

As you can see, words that are frequently used together, in place of each other, or in the same kinds of contexts tend to be grouped together -- the stronger these correlations, the closer together they are. For example, "None" is semantically close to "Zero", while a sentence that uses "Italy" could probably also use "Germany" with little damage to the sentence structure. Vector "closeness" for similar words like this is a good indicator of a well-built model.
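
As a small, hypothetical illustration of this kind of "closeness" (the vectors and words below are made up for illustration and not taken from any trained model), cosine similarity between embedding vectors is a common way to measure it:

import numpy as np

# Hypothetical 4-dimensional embeddings, just to illustrate vector "closeness";
# real embeddings in this notebook are 200-dimensional and learned during training.
vec = {
    "zero":  np.array([0.9, 0.1, 0.0, 0.2]),
    "none":  np.array([0.8, 0.2, 0.1, 0.3]),
    "italy": np.array([0.1, 0.9, 0.8, 0.0]),
}

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

print(cosine_similarity(vec["zero"], vec["none"]))   # high -- similarly used words
print(cosine_similarity(vec["zero"], vec["italy"]))  # low  -- unrelated words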


We need to import the necessary modules for our code. We need numpy and tensorflow, obviously. Additionally, we import a PTB reader helper module (downloaded below as penn_treebank_reader.py), which takes care of reading the input data from the dataset we will download shortly.

If you want to learn more, take a look at https://www.tensorflow.org/versions/r0.11/api_docs/python/rnn_cell/


In [1]:
import time
import numpy as np
import tensorflow as tf
import os

print('TensorFlow version: ', tf.__version__)


TensorFlow version:  1.1.0

In [2]:
tf.reset_default_graph()

In [3]:
if not os.path.isfile('./penn_treebank_reader.py'):
    print('Downloading penn_treebank_reader.py...')
    !wget -q -O ../../data/Penn_Treebank/ptb.zip https://ibm.box.com/shared/static/z2yvmhbskc45xd2a9a4kkn6hg4g4kj5r.zip
    !unzip -o ../../data/Penn_Treebank/ptb.zip -d ../../data/Penn_Treebank
    !cp ../../data/Penn_Treebank/ptb/reader.py ./penn_treebank_reader.py
else:
    print('Using local penn_treebank_reader.py...')


Using local penn_treebank_reader.py...

In [4]:
import penn_treebank_reader as reader

Building the LSTM model for Language Modeling

Now that we know exactly what we are doing, we can start building our model using TensorFlow. The very first thing we need to do is download and extract the simple-examples dataset, which can be done by executing the code cell below.


In [5]:
if not os.path.isfile('../../data/Penn_Treebank/simple_examples.tgz'):
    !wget -O ../../data/Penn_Treebank/simple_examples.tgz http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 
    !tar xzf ../../data/Penn_Treebank/simple_examples.tgz -C ../../data/Penn_Treebank/


--2017-05-23 10:45:45--  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Resolving www.fit.vutbr.cz (www.fit.vutbr.cz)... 147.229.9.23, 2001:67c:1220:809::93e5:917
Connecting to www.fit.vutbr.cz (www.fit.vutbr.cz)|147.229.9.23|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34869662 (33M) [application/x-gtar]
Saving to: ‘../../data/Penn_Treebank/simple_examples.tgz’

../../data/Penn_Tre 100%[===================>]  33,25M  3,79MB/s    in 9,1s    

2017-05-23 10:45:55 (3,67 MB/s) - ‘../../data/Penn_Treebank/simple_examples.tgz’ saved [34869662/34869662]


Additionally, for the sake of making it easy to play around with the model's hyperparameters, we can declare them beforehand. Feel free to change these -- you will see a difference in performance each time you change those!


In [9]:
#Initial weight scale
init_scale = 0.1
#Initial learning rate
learning_rate = 1.0
#Maximum permissible norm for the gradient (For gradient clipping -- another measure against Exploding Gradients)
max_grad_norm = 5
#The number of layers in our model
num_layers = 2
#The total number of recurrence steps, i.e. the number of time steps our RNN spans when it is "unfolded"
num_steps = 20
#The number of processing units (neurons) in the hidden layers
hidden_size = 200
#The maximum number of epochs trained with the initial learning rate
max_epoch = 4
#The total number of epochs in training
max_max_epoch = 13
#The probability for keeping data in the Dropout Layer (This is an optimization, but is outside our scope for this notebook!)
#At 1, we ignore the Dropout Layer wrapping.
keep_prob = 1
#The decay for the learning rate
decay = 0.5
#The size for each batch of data
batch_size = 30
#The size of our vocabulary
vocab_size = 10000
#Training flag to separate training from testing
is_training = 1
#Data directory for our dataset
data_dir = "../../data/Penn_Treebank/simple-examples/data/"

Some clarifications about the LSTM architecture based on the arguments above:

Network structure:

  • In this network, the number of LSTM layers is 2 (num_layers). To give the model more expressive power, we stack multiple layers of LSTMs to process the data: the output of the first layer becomes the input of the second, and so on.
  • The number of recurrence steps is 20 (num_steps), that is, our RNN spans 20 time steps when it is "unfolded".
  • The structure looks like this (see the weight-shape check after this list):
    • 200 input units -> [200x200] weight matrix -> 200 hidden units (first layer) -> [200x200] weight matrix -> 200 hidden units (second layer) -> [200x10000] softmax weight matrix -> 10000-unit output (one logit per vocabulary word)

Hidden layer:

  • Each LSTM layer has 200 hidden units, which is equal to the dimensionality of the word embeddings and of the layer's output.

Input layer:

  • The network has 200 input units.
  • Suppose each word is represented by an embedding vector of dimensionality e=200. The input layer of each cell will have 200 linear units. These e=200 linear units are connected to each of the h=200 LSTM units in the hidden layer (assuming there is only one hidden layer, though our case has 2 layers).
  • The input shape is [batch_size, num_steps], that is [30x20]. It turns into [30x20x200] after the embedding lookup, which you can think of as 20 time steps of [30x200] matrices.
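
As a rough sanity check of these shapes (a sketch based on the hyperparameters above; the exact variable layout depends on the TensorFlow version), each BasicLSTMCell concatenates its input with its previous hidden state and feeds the result through four gates, which is why the trainable kernel variables we inspect later have shape (400, 800):

embedding_dim = 200   # size of each word vector fed into the first layer
hidden_size = 200     # LSTM units per layer
# input and previous hidden state are concatenated: 200 + 200 = 400 rows
# the four LSTM gates share a single kernel: 4 * 200 = 800 columns
lstm_kernel_shape = (embedding_dim + hidden_size, 4 * hidden_size)
print(lstm_kernel_shape)  # (400, 800), matching the variables listed under "Trainable Variables" below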

There is a lot to be done and a ton of information to process at the same time, so go over this code slowly. It may seem complex at first, but if you try to relate what you just learned about language modelling to the code you see, you should be able to understand it.

This code is adapted from the PTBModel example bundled with the TensorFlow source code.

Train data

The story starts from data:

  • The training data is a list of words represented by ids -- N=929589 numbers in total, e.g. [9971, 9972, 9974, 9975,...]
  • We read the data in mini-batches of size b=30. Assuming each sentence is 20 words long (num_steps = 20), it takes roughly ((N // b) - 1) // num_steps ≈ 1549 iterations for the learner to go through all of the data once (see the quick check after this list). So, the number of iterations per epoch is about 1549.
  • Each batch read from the training set therefore contains 600 word ids and has shape [30x20]
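
A quick back-of-the-envelope check of that iteration count (a sketch that assumes the reader uses the usual PTB epoch-size formula, epoch_size = ((len(data) // batch_size) - 1) // num_steps):

N = 929589                                # number of word ids in the training data
b = 30                                    # batch size
num_steps = 20                            # words per sentence/segment
batch_len = N // b                        # 30986 word ids per batch row
epoch_size = (batch_len - 1) // num_steps
print(epoch_size)                         # ~1549 iterations per epoch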

First we start an interactive session:


In [10]:
session=tf.InteractiveSession()

In [11]:
# Reads the data and separates it into training data, validation data and testing data
raw_data = reader.ptb_raw_data(data_dir)
train_data, valid_data, test_data, _ = raw_data

Let's just read one mini-batch now and feed it to our network:


In [12]:
itera = reader.ptb_iterator(train_data, batch_size, num_steps)
first_touple=next(itera)
x=first_touple[0]
y=first_touple[1]

In [13]:
x.shape


Out[13]:
(30, 20)

Let's look at 3 sentences of our input x:


In [14]:
x[0:3]


Out[14]:
array([[9970, 9971, 9972, 9974, 9975, 9976, 9980, 9981, 9982, 9983, 9984,
        9986, 9987, 9988, 9989, 9991, 9992, 9993, 9994, 9995],
       [2654,    6,  334, 2886,    4,    1,  233,  711,  834,   11,  130,
         123,    7,  514,    2,   63,   10,  514,    8,  605],
       [   0, 1071,    4,    0,  185,   24,  368,   20,   31, 3109,  954,
          12,    3,   21,    2, 2915,    2,   12,    3,   21]], dtype=int32)

In [15]:
size = hidden_size

We define two placeholders to feed with our mini-batches, that is, x and y:


In [16]:
_input_data = tf.placeholder(tf.int32, [batch_size, num_steps]) #[30x20]
_targets = tf.placeholder(tf.int32, [batch_size, num_steps]) #[30x20]

Let's define a dictionary, and use it later to feed the placeholders with our first mini-batch:


In [17]:
feed_dict={_input_data:x, _targets:y}

For example, we can use it to feed _input_data:


In [18]:
session.run(_input_data,feed_dict)


Out[18]:
array([[9970, 9971, 9972, 9974, 9975, 9976, 9980, 9981, 9982, 9983, 9984,
        9986, 9987, 9988, 9989, 9991, 9992, 9993, 9994, 9995],
       [2654,    6,  334, 2886,    4,    1,  233,  711,  834,   11,  130,
         123,    7,  514,    2,   63,   10,  514,    8,  605],
       [   0, 1071,    4,    0,  185,   24,  368,   20,   31, 3109,  954,
          12,    3,   21,    2, 2915,    2,   12,    3,   21],
       [   3,   71,    4,   27,  246,   60,   11,  215,    4,    1, 1846,
           9,    3,   71,  546,    2, 6505,  162,    6,  104],
       [  93,   25,    6,  261,  681,  251,    0,  278, 3246,   13,  200,
           1,    8,  105, 3360,    1,    4,    0,  536,    4],
       [  20,    6,  954,   12,    3,   21,   78,   14,  977,  726,    0,
          37,   42,   34,    5,  437,  116,  206,  927,    2],
       [  18,  296,    7,  201,   76,    4,  182,  560, 3836,   17,  974,
         975,    6,  942,    4,  156,   77, 1570,  288,  644],
       [  23, 1238,  899,    5,   25,  201,    4,    0,  434,  642,   55,
         201,    4,    0, 2423,    2,    1,    1,    1,  483],
       [ 379,  706,    9,  413, 8219,   96,   15,    0, 2185, 1758,    1,
           1,   37,   13,  834,    5,  852,  222,    7, 1785],
       [   2,  179,  940,  117,   38,   59,  677,   14,    1,   10, 1016,
         309,   13, 1077, 6360,   16,   23, 4490,    9,  355],
       [3572,    4, 3015, 1347,  536,   13,    6, 3949,    5,  438, 9643,
           2,   64,   87,   32,  358, 3672, 4103, 1082,   11],
       [  71,  178,    3,    8,    3,    2,    0, 1008,  234,   30, 6400,
          10,    0,   98,    9,    1,  338,   13,    5,   25],
       [1473,   88,   19, 2578, 6591,    8,  629,  563,    8,  223,  184,
         127,   18,    6,  828,    1,    2,    0,  324,  158],
       [   1,    1,    2,   18,    0, 1844,    4,   73,   39, 2694,    6,
        1709,    2,    7,    0, 6509, 1116,   27,    1,    1],
       [1055,    5,   25, 8582,   10,  353,  645,   24,    6,  287,    2,
        1006,    0, 8861, 2369,   44,    7,    0,    1,  180],
       [  36,  501,    5,    6, 1969,    0,   98,   89, 2254,    0,  312,
        1641,    4, 1063,    8,  713,    0,  264,  820,    2],
       [  32, 2599,  762, 1875,   26, 1402,   45,  516,    2, 2937,   16,
        3355, 2062,  251,    0,  529,   24, 1625,  122,   18],
       [ 677,  127,    2,   19,   23, 7800, 3592,   14,   64,   87,   32,
         350,    0, 3968,    2,   38,   26,  114,   38,   26],
       [  25,   45,  769,    2,   23, 2634, 1096, 1175,   19,    6,    1,
         154,   23, 1890,   30,    6,    1,    1,    2,  198],
       [7736,  391,    5, 5173,  838,    2,  840,    9, 8716,  537, 4132,
        2915,    9,    1,    1,   10, 1268,  175,   32,  184],
       [   3,   21,    4,    1,  308,  458,   11,   41,   14, 5718,  102,
         824,    1,    2,   14,   59,   50,   12,    3,   21],
       [   8,    1,   22,   73,   10,  863,   11,  898,  653,  270,    8,
         500,  273, 1559,    2,   14, 3019,    5,  585,   84],
       [ 483,  762,   87,  108, 1119,    0,    1,   67,    0, 3296,   26,
         591,  174,  127,    2,  108,   26, 9821,   11,    6],
       [3885,  582,   81,   17, 1834,    2, 1256,   98,  162,  582,  441,
         125,   22, 1652,  172,    4,    3,    3,    8,  206],
       [  44,   23,    1,    0, 1704,    4,    1,    2,   22,  373,   38,
         275,    1, 8017,    2, 2785, 3659, 4359,   80,  634],
       [1896,    8,   13, 9468,   17,  752, 4622,    2,   29, 2221,    0,
         446, 3552,    4,    0, 2495,  431,  134,  284,  152],
       [  48,    7, 1741,  193,    8,  446,  165,  301, 6521, 5122,   15,
          12,    3,   21,    4,   10,  161,  783,    8,   79],
       [  47, 4447, 1431,    4, 6967, 2121,   24,  452,   18,   43,    3,
          48, 1076,   12,    3,   21,   69,   40,    2, 1323],
       [  31, 3374,    4, 2108,    1,  134,    8, 6967, 1825, 3306,   14,
          13, 3581,    5, 2424, 1583, 6495,    5,    6, 1136],
       [  59, 2070, 2433,   28,  517,   20,   23, 4306,    6,   40,  195,
           2, 9398,  400, 4908,  673, 1572,  400,    1, 1173]], dtype=int32)

In this step, we create the stacked LSTM, which is a 2-layer LSTM network:


In [19]:
stacked_lstm = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=0.0) 
                                            for _ in range(num_layers)]
                                          )

Also, we initialize the states of the network:

_initial_state

For each LSTM layer, there are 2 state matrices: c_state, the cell ("memory") state, and h_state (sometimes called m_state), the hidden/output state. Each layer keeps one state vector per example in the batch, so with a batch of 30 and 200 hidden units per LSTM, each state is a matrix of size [30x200].


In [20]:
_initial_state = stacked_lstm.zero_state(batch_size, tf.float32)
_initial_state


Out[20]:
(LSTMStateTuple(c=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros:0' shape=(30, 200) dtype=float32>, h=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros_1:0' shape=(30, 200) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState_1/zeros:0' shape=(30, 200) dtype=float32>, h=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState_1/zeros_1:0' shape=(30, 200) dtype=float32>))

Let's look at the states, though they are all zeros for now:


In [21]:
session.run(_initial_state,feed_dict)


Out[21]:
(LSTMStateTuple(c=array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32), h=array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)),
 LSTMStateTuple(c=array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32), h=array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)))

Embeddings

We create the embeddings for our input data. embedding is a lookup table of shape [10000x200] -- one 200-dimensional vector for each of the 10,000 unique words in the vocabulary.


In [22]:
try:
    embedding = tf.get_variable("embedding", [vocab_size, hidden_size])  #[10000x200]
except ValueError:
    pass
embedding.get_shape().as_list()


Out[22]:
[10000, 200]

In [23]:
session.run(tf.global_variables_initializer())
session.run(embedding, feed_dict)


Out[23]:
array([[ 0.00419833, -0.011482  , -0.01549463, ...,  0.01950766,
         0.01832177, -0.02384475],
       [ 0.02171985,  0.00561222,  0.01664338, ...,  0.0237095 ,
        -0.01921788, -0.01106183],
       [ 0.01881259,  0.02292644,  0.00497515, ...,  0.01437099,
         0.01298326, -0.02087402],
       ..., 
       [ 0.01548864,  0.017141  , -0.02069814, ..., -0.01426928,
         0.00501703, -0.02374811],
       [ 0.02391572,  0.01972473, -0.01133233, ...,  0.01748331,
         0.01228622,  0.01979224],
       [ 0.00684312, -0.00510194,  0.01945009, ..., -0.00847306,
         0.02030636, -0.0018806 ]], dtype=float32)

embedding_lookup goes through each row of _input_data and, for each word id in the row/sentence, finds the corresponding vector in embedding. The result is a [30x20x200] tensor, so the first element of inputs (the first sentence) is a [20x200] matrix in which each row is a vector representing one word of the sentence.


In [24]:
# Define where to get the data for our embeddings from
inputs = tf.nn.embedding_lookup(embedding, _input_data)  #shape=(30, 20, 200)

In [25]:
inputs


Out[25]:
<tf.Tensor 'embedding_lookup:0' shape=(30, 20, 200) dtype=float32>

In [26]:
session.run(inputs[0], feed_dict)


Out[26]:
array([[-0.01766763, -0.0102601 , -0.00787331, ..., -0.0012469 ,
        -0.01207721, -0.01660829],
       [ 0.02301355,  0.02373759, -0.01453253, ..., -0.02158061,
         0.01211887, -0.01520211],
       [-0.00116363,  0.01608609,  0.00544146, ..., -0.00878267,
         0.00452347,  0.00519271],
       ..., 
       [ 0.01150502,  0.00421331, -0.0161419 , ...,  0.0184783 ,
         0.01090024, -0.01899387],
       [ 0.00922056,  0.01489506, -0.01159351, ..., -0.00680826,
        -0.0197531 ,  0.01292341],
       [ 0.01657515, -0.00247201,  0.02124329, ...,  0.01656852,
        -0.01288021,  0.00498772]], dtype=float32)

Constructing Recurrent Neural Networks

tf.nn.dynamic_rnn() creates a recurrent neural network using stacked_lstm, which is an instance of RNNCell.

The input should be a Tensor of shape: [batch_size, max_time, ...], in our case it would be (30, 20, 200)

This method returns a pair (outputs, new_state) where:

  • outputs is a Tensor of shape [batch_size, max_time, cell.output_size] -- in our case (30, 20, 200), one output vector per time step.
  • new_state is the final state of the network

In [27]:
outputs, new_state =  tf.nn.dynamic_rnn(stacked_lstm, inputs, initial_state=_initial_state)

So, let's look at the outputs. At each of the 20 time steps, the stacked LSTM produces a 200-dimensional output vector (one value per hidden unit). Later, we use a linear (softmax) layer to map these 200-dimensional outputs to a [?x10000] matrix of logits over the vocabulary.


In [28]:
outputs


Out[28]:
<tf.Tensor 'rnn/transpose:0' shape=(30, 20, 200) dtype=float32>

In [29]:
session.run(tf.global_variables_initializer())
session.run(outputs[0], feed_dict)


Out[29]:
array([[  2.31600206e-04,  -9.61103651e-05,  -1.28215936e-04, ...,
         -1.00696590e-04,  -1.18866556e-04,   3.85970998e-06],
       [  5.98545128e-04,  -1.57938353e-04,   6.04075285e-05, ...,
          7.42806660e-05,  -4.91725732e-05,   8.44031911e-06],
       [  4.31441877e-04,  -7.12473193e-05,  -2.78066946e-05, ...,
          4.59898321e-04,   2.40637557e-04,  -6.21227009e-05],
       ..., 
       [  4.11826361e-04,  -3.25875997e-04,   2.63570619e-05, ...,
         -5.27737895e-04,   4.79726616e-04,   2.65730941e-05],
       [  6.67331857e-04,  -5.47913078e-04,  -2.34936088e-04, ...,
         -7.87418219e-04,   1.08516018e-03,   4.59716512e-06],
       [  4.38249495e-04,  -5.56082989e-04,  -7.99677218e-04, ...,
         -1.04818272e-03,   1.60882762e-03,  -3.46088491e-04]], dtype=float32)

Let's reshape the output tensor from [30x20x200] to [600x200]:


In [30]:
output = tf.reshape(outputs, [-1, size])
output


Out[30]:
<tf.Tensor 'Reshape:0' shape=(600, 200) dtype=float32>

In [31]:
session.run(output[0], feed_dict)


Out[31]:
array([  2.31600206e-04,  -9.61103651e-05,  -1.28215936e-04,
         1.82282485e-04,  -1.81498559e-04,   7.17349540e-05,
         2.64149548e-05,   1.84381948e-04,  -1.61701348e-04,
         5.18249755e-04,  -2.19192108e-04,   4.76456655e-04,
        -2.41356611e-04,   7.81506897e-05,   5.45248622e-04,
         1.34541493e-04,  -1.13408685e-04,   3.70240392e-04,
         2.36539927e-04,  -3.05564376e-04,   5.46291849e-05,
        -2.01269737e-04,  -4.69016639e-04,   5.44591494e-05,
         4.91849962e-04,   2.42963564e-04,  -3.79720383e-04,
         5.38949971e-04,  -3.36418016e-04,  -2.48538036e-05,
        -4.22927173e-04,   1.71456923e-04,   3.96906282e-04,
         3.13919358e-04,   6.60799269e-05,  -6.45390013e-04,
         3.86007683e-04,  -8.17814653e-05,  -2.69292592e-04,
        -1.25197927e-04,  -2.99289677e-04,  -3.30634910e-04,
        -2.20897768e-04,  -1.60990108e-04,  -4.80899558e-04,
         9.81728808e-05,   5.63230560e-06,   1.78039350e-04,
        -1.75673253e-04,   1.52502462e-05,   3.18863691e-04,
        -2.14501692e-04,  -3.85752413e-04,  -2.60850298e-04,
         2.00806389e-04,  -6.49215581e-05,  -5.33773855e-04,
        -2.78106072e-05,  -1.33143025e-04,   4.86028621e-05,
         8.62033412e-05,   2.16736502e-04,  -1.68910803e-04,
         8.07778633e-05,  -3.61343300e-05,  -3.78177210e-04,
        -1.76992908e-05,  -1.91219646e-04,   2.73561949e-04,
         1.02845159e-04,   3.10430823e-06,   1.20603174e-04,
         2.87191273e-04,  -1.31608685e-04,  -2.39275774e-04,
        -1.09928951e-04,   3.52104980e-04,   2.39414992e-04,
         7.14096939e-04,  -1.60532800e-04,  -3.09004536e-04,
        -1.65169709e-04,   1.61329706e-06,   2.65004375e-04,
         2.45452713e-04,   1.77102411e-04,   4.51763684e-04,
         2.24396339e-04,  -2.38765584e-04,   4.64103323e-05,
        -2.46514071e-04,   1.77763781e-04,   2.63494236e-04,
         3.96781164e-04,   7.36463044e-05,   2.74454243e-04,
        -1.74387227e-04,  -5.63693211e-05,  -2.66648203e-06,
         4.65694407e-04,  -8.88080649e-06,  -1.53669913e-04,
         1.05132909e-04,   2.04267315e-04,  -1.34888396e-04,
        -5.24983261e-05,   3.10179857e-05,   1.20539007e-04,
         3.15797224e-04,  -5.34937717e-04,   1.87294630e-04,
         2.88283889e-04,   4.08066291e-04,  -1.97968020e-05,
         8.61582794e-06,   9.71457193e-05,  -9.95817536e-05,
         5.78351246e-05,   3.55252065e-04,  -9.65764630e-05,
        -1.26164465e-04,  -2.44366849e-04,  -7.49962692e-06,
         5.04637661e-04,  -2.79218133e-04,   1.21676065e-04,
        -2.53711478e-04,  -4.42211225e-04,  -1.20862744e-04,
         3.65001237e-04,   3.59707075e-04,   3.23371496e-04,
         2.31745755e-04,   2.06724668e-04,  -8.08368131e-05,
         9.10271556e-05,   1.46002887e-04,  -2.12358951e-04,
         5.98551414e-04,  -8.54013706e-06,   2.13577761e-04,
         2.43875111e-04,   1.32346235e-04,  -4.82845149e-04,
         1.83731652e-04,  -1.35736089e-04,  -1.54302747e-04,
         1.82079151e-04,  -2.82852998e-04,  -1.54735390e-04,
         1.20612611e-04,   2.27278011e-04,  -1.97871686e-05,
        -5.06633136e-04,   6.01693697e-04,  -4.35120410e-06,
        -5.11175895e-05,  -3.73732510e-05,   8.23505907e-05,
        -5.41476184e-04,  -3.21346422e-04,   1.80378393e-05,
        -7.24676556e-06,   4.34288158e-05,   1.14305323e-04,
        -2.66879390e-04,   1.45480253e-05,  -5.11497819e-05,
        -3.48389207e-04,  -4.56573587e-04,  -1.03762657e-04,
         4.64261393e-04,  -6.93021429e-05,  -4.47534345e-04,
         3.83870763e-04,   1.47205923e-04,   3.53405339e-05,
         3.27963403e-06,  -7.10431996e-05,   1.58949711e-04,
        -1.12677426e-05,   4.93735890e-04,  -2.21638111e-04,
        -9.81052945e-05,  -1.41739161e-04,   1.91019808e-05,
         4.29605461e-05,  -1.89026308e-04,  -2.95807346e-04,
        -5.42630209e-04,   3.80970247e-04,  -5.20867412e-04,
         3.15613463e-04,   4.93762600e-05,   5.76188067e-05,
        -2.34484018e-04,  -8.10405400e-05,  -1.00696590e-04,
        -1.18866556e-04,   3.85970998e-06], dtype=float32)

Logistic unit

Now, we create a logistic (softmax) unit to return the probability of the output word. That is, we map the 600 LSTM outputs of size 200 to logits over the 10,000-word vocabulary:

Softmax = [600 x 200] * [200 x 10000] + [1 x 10000] -> [600 x 10000]


In [32]:
softmax_w = tf.get_variable("softmax_w", [size, vocab_size]) #[200x10000]
softmax_b = tf.get_variable("softmax_b", [vocab_size]) #[1x10000]
logits = tf.matmul(output, softmax_w) + softmax_b

In [33]:
session.run(tf.global_variables_initializer())
logi = session.run(logits, feed_dict)
logi.shape


Out[33]:
(600, 10000)

In [34]:
First_word_output_probablity = logi[0]
First_word_output_probablity.shape


Out[34]:
(10000,)

Prediction

The maximum probability (we take the argmax over the output logits):


In [35]:
embedding_array= session.run(embedding, feed_dict)
np.argmax(First_word_output_probablity)


Out[35]:
8647

So, what is the ground truth for the first word of the first sentence?


In [36]:
y[0][0]


Out[36]:
9971

You can also get it from the target tensor, if you want to find the embedding vector:


In [37]:
_targets


Out[37]:
<tf.Tensor 'Placeholder_1:0' shape=(30, 20) dtype=int32>

It is time to compare the logits with the targets:


In [38]:
targ = session.run(tf.reshape(_targets, [-1]), feed_dict)

In [39]:
first_word_target_code= targ[0]
first_word_target_code


Out[39]:
9971

In [40]:
first_word_target_vec = session.run( tf.nn.embedding_lookup(embedding, targ[0]))
first_word_target_vec


Out[40]:
array([  1.86736062e-02,   1.61858462e-02,  -1.40998391e-02,
        -1.36882123e-02,  -1.33983875e-02,   2.05998421e-02,
         2.46733427e-04,  -1.95281785e-02,  -1.66366622e-02,
         2.35252567e-02,  -1.56327747e-02,  -1.79377422e-02,
        -1.52281262e-02,   2.24282108e-02,  -2.25153193e-02,
         1.12179555e-02,   1.13971643e-02,   2.04268061e-02,
        -1.63705237e-02,  -1.23975612e-03,  -1.90807115e-02,
         2.20058970e-02,   1.95904449e-02,  -6.56226464e-03,
        -1.81781966e-02,  -1.43649057e-04,   9.90456715e-03,
         6.35352731e-03,  -1.19924247e-02,   1.18489973e-02,
         2.35010795e-02,   4.13514674e-03,  -9.61622782e-03,
         9.97390598e-03,  -7.55315647e-04,  -1.23191299e-02,
         1.46972574e-02,  -1.02031920e-02,   2.31070556e-02,
         1.71499401e-02,   8.16120580e-03,   8.09521228e-03,
        -1.05001172e-02,  -2.01561172e-02,  -8.07870179e-03,
         5.75733930e-03,  -1.35746440e-02,   9.17498395e-03,
         5.88469952e-03,   6.91398419e-03,  -2.21318193e-02,
         4.06277366e-03,   9.75106284e-03,  -2.21701339e-04,
        -1.55733824e-02,   9.18336213e-04,  -1.30402660e-02,
        -2.35606972e-02,  -1.80996414e-02,  -1.20734898e-02,
         1.06933340e-02,   7.77347758e-03,   1.42692029e-02,
        -1.79187302e-02,   1.74265653e-02,   4.76434082e-03,
        -1.90977417e-02,  -2.13172548e-02,   1.78024359e-02,
         1.66784972e-02,  -1.95194464e-02,  -1.14144757e-03,
        -4.75083292e-03,  -8.06572475e-03,   1.59807056e-02,
        -2.16284636e-02,  -9.59825609e-03,   1.11175627e-02,
         2.97047570e-03,  -3.44133377e-03,   4.08265367e-03,
        -4.69078682e-03,  -5.46276569e-05,  -1.29350992e-02,
        -9.14853346e-03,   3.77608463e-04,  -1.82975177e-02,
        -2.06317343e-02,   5.65472431e-03,   9.53562558e-04,
         1.97624639e-02,   1.33274980e-02,   1.70712918e-02,
        -1.39260748e-02,   1.20430440e-03,   1.52990036e-02,
         1.99222267e-02,  -5.69302216e-03,   2.31241658e-02,
         1.44271292e-02,  -1.52906468e-02,  -1.16468742e-02,
        -1.54872751e-02,   2.37220228e-02,   6.80477731e-03,
         1.75829120e-02,   1.75383799e-02,  -1.62591226e-03,
         2.39939988e-02,  -2.08101347e-02,  -4.72830422e-03,
        -1.40071111e-02,   9.64150950e-03,  -7.35162757e-03,
         4.70679812e-03,  -1.71105377e-02,  -2.59642862e-03,
         1.55653805e-04,  -9.49554145e-03,  -7.91230984e-03,
        -5.49445115e-03,  -2.32011061e-02,   2.59431265e-03,
         2.36858763e-02,  -3.92580219e-03,  -5.88322431e-03,
         9.08011571e-03,   1.05910562e-03,  -1.05537614e-02,
        -1.98659971e-02,  -7.75912777e-04,   7.82781094e-03,
         6.08453713e-03,   1.99076161e-02,  -3.40596214e-03,
        -1.06713315e-02,  -1.66793019e-02,  -1.60669312e-02,
         7.43065029e-03,   2.23596022e-03,  -9.18757729e-03,
         9.98016819e-03,   1.11559443e-02,  -1.23696635e-02,
         2.36304738e-02,   1.45839341e-03,  -8.04433599e-03,
        -2.29621325e-02,   6.99994713e-03,   1.62947327e-02,
        -1.52833201e-02,  -1.06670177e-02,   4.73672897e-03,
         4.43332642e-03,  -3.50471586e-03,  -2.26330012e-03,
         2.39891633e-02,  -2.05450654e-02,   2.23836191e-02,
        -5.68466261e-04,  -2.36727390e-02,  -6.54226914e-03,
         1.86842978e-02,  -2.49825418e-03,   1.69598982e-02,
        -1.46265673e-02,   1.39764994e-02,   1.24187768e-03,
         1.26068853e-03,   1.09120458e-02,  -2.27538943e-02,
         5.48322685e-03,  -6.14884496e-03,   1.44854039e-02,
        -2.17009187e-02,  -9.90571175e-03,   1.49856880e-02,
        -9.89071745e-03,  -1.75629091e-02,   1.90865248e-03,
         8.36897269e-03,   2.39523649e-02,  -2.28534229e-02,
         2.03119218e-02,   8.92897323e-03,   1.32517777e-02,
        -2.53174454e-04,   1.21902041e-02,   2.33942643e-02,
         1.13995522e-02,  -1.79807059e-02,  -1.18939374e-02,
         4.83300723e-03,   8.43480602e-03,   1.95643716e-02,
        -2.03024894e-02,  -7.41455331e-03,   1.45585835e-02,
        -1.35253947e-02,  -1.16004115e-02], dtype=float32)

Objective function

Now we want to define our objective function. Our objective is to minimize the loss function, that is, to minimize the average negative log probability of the target words:

$$\text{loss} = -\frac{1}{N}\sum_{i=1}^{N} \ln p_{\mathrm{target}_i}$$
This function is already implemented and available in TensorFlow as sequence_loss_by_example, so we can just use it here. sequence_loss_by_example is the weighted cross-entropy loss for a sequence of logits (per example).

Its arguments:

logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
targets: List of 1D batch-sized int32 Tensors of the same length as logits.
weights: List of 1D batch-sized float-Tensors of the same length as logits.


In [41]:
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [tf.reshape(_targets, [-1])],[tf.ones([batch_size * num_steps])])

loss is a 1D float Tensor of length 600 (batch_size * num_steps): the log-perplexity (negative log probability) for each predicted word.


In [42]:
session.run(loss, feed_dict)


Out[42]:
array([ 9.20420647,  9.20577335,  9.19770718,  9.19478989,  9.21243382,
        9.21494293,  9.21947575,  9.22757244,  9.22758675,  9.21636772,
        9.20103645,  9.19449234,  9.22175789,  9.20829582,  9.21987534,
        9.22261238,  9.19444084,  9.19685555,  9.22691154,  9.21609497,
        9.22377777,  9.21007824,  9.20543385,  9.20270443,  9.21428204,
        9.21068382,  9.22520828,  9.20246506,  9.19868851,  9.20911407,
        9.2212925 ,  9.21260929,  9.21115303,  9.21098137,  9.21868134,
        9.21260643,  9.21111774,  9.22663784,  9.19542694,  9.22124577,
        9.21371841,  9.20293427,  9.21089554,  9.20351315,  9.21870041,
        9.22543907,  9.21017075,  9.20417595,  9.20265102,  9.2133112 ,
        9.22036934,  9.19654942,  9.20777988,  9.21099377,  9.2257061 ,
        9.21096611,  9.22043228,  9.19632721,  9.20783806,  9.20282078,
        9.19964981,  9.2028656 ,  9.2237215 ,  9.21790695,  9.21510983,
        9.19886875,  9.22486019,  9.20293617,  9.21429253,  9.2215786 ,
        9.22514534,  9.19634056,  9.19962311,  9.21320438,  9.21096325,
        9.21628666,  9.19364643,  9.22359562,  9.20597076,  9.19621468,
        9.20256996,  9.22379494,  9.21220207,  9.20674992,  9.21286297,
        9.2110281 ,  9.19412708,  9.20481777,  9.21700668,  9.22420883,
        9.21402836,  9.22650909,  9.1994915 ,  9.22349644,  9.21412659,
        9.20283794,  9.21102238,  9.22106171,  9.20292377,  9.20884323,
        9.22367382,  9.21342468,  9.22038937,  9.19642639,  9.20784473,
        9.20223427,  9.22004509,  9.20266533,  9.2232399 ,  9.21086693,
        9.19854164,  9.21203041,  9.20379162,  9.21970463,  9.20142078,
        9.2190094 ,  9.21865749,  9.19460773,  9.21101189,  9.21279335,
        9.22197342,  9.21280575,  9.20464039,  9.19717216,  9.20272827,
        9.22537613,  9.21212769,  9.20954895,  9.22738075,  9.1997776 ,
        9.21102238,  9.22382355,  9.21862698,  9.2028141 ,  9.2149334 ,
        9.19594574,  9.22044659,  9.20954418,  9.20270348,  9.21116257,
        9.21906281,  9.19820976,  9.21976376,  9.20242786,  9.20463276,
        9.20276451,  9.2108717 ,  9.21390057,  9.20038509,  9.2190218 ,
        9.20458317,  9.20291519,  9.21092129,  9.19541359,  9.2111311 ,
        9.21418476,  9.21423912,  9.21431446,  9.22615147,  9.2109766 ,
        9.20167351,  9.22479534,  9.20908451,  9.20911312,  9.22554016,
        9.20401001,  9.21083927,  9.21886158,  9.19989586,  9.21428776,
        9.21430588,  9.19838238,  9.21683693,  9.20245075,  9.22003365,
        9.212286  ,  9.21533394,  9.21262169,  9.21365833,  9.19641399,
        9.19419193,  9.2082262 ,  9.21550941,  9.21216297,  9.208395  ,
        9.20496178,  9.21972942,  9.21446705,  9.21296883,  9.22455883,
        9.1949501 ,  9.21695137,  9.22055244,  9.19729137,  9.22016144,
        9.22402668,  9.20158958,  9.224823  ,  9.22719669,  9.22466278,
        9.20284939,  9.22407246,  9.20884609,  9.22093296,  9.21667957,
        9.22361946,  9.20969963,  9.21984386,  9.21697617,  9.20216942,
        9.21105671,  9.20052719,  9.19454193,  9.21272564,  9.20111275,
        9.21289539,  9.20702457,  9.20488453,  9.19894791,  9.21450901,
        9.19551659,  9.1964941 ,  9.22664642,  9.19641018,  9.210989  ,
        9.21065235,  9.21642399,  9.20853043,  9.21987152,  9.22688198,
        9.21278477,  9.21080589,  9.22194767,  9.22480011,  9.21428204,
        9.21657181,  9.21684265,  9.21982098,  9.20258999,  9.20677376,
        9.22043896,  9.22486496,  9.19889832,  9.22458363,  9.22648716,
        9.21188641,  9.2261219 ,  9.22631073,  9.22445583,  9.22542572,
        9.21421909,  9.21630955,  9.22387123,  9.19745922,  9.21421623,
        9.21098232,  9.21083832,  9.19422436,  9.22717762,  9.19864273,
        9.21427917,  9.21105957,  9.21631241,  9.21089268,  9.22660732,
        9.20282555,  9.19810772,  9.19913292,  9.21175861,  9.22374439,
        9.22190285,  9.21096039,  9.21264362,  9.21077538,  9.20790672,
        9.20310783,  9.22347736,  9.21443081,  9.21444798,  9.21445084,
        9.21971607,  9.20259666,  9.20176411,  9.21271992,  9.22131729,
        9.22745609,  9.21851635,  9.22369289,  9.21360397,  9.21115875,
        9.22175026,  9.21086979,  9.20043945,  9.19713974,  9.21523952,
        9.21272945,  9.21087551,  9.21438789,  9.19538689,  9.20816612,
        9.22001648,  9.21983624,  9.22360039,  9.1999054 ,  9.21083546,
        9.22194195,  9.20989227,  9.21076679,  9.21079445,  9.22084904,
        9.21862411,  9.20283508,  9.1957159 ,  9.22666454,  9.20697594,
        9.21084595,  9.20883846,  9.2112484 ,  9.21114159,  9.19487476,
        9.21354008,  9.20010662,  9.20306683,  9.2207365 ,  9.21721935,
        9.19930077,  9.19505024,  9.21125698,  9.19325638,  9.22000504,
        9.20708752,  9.22071075,  9.21291828,  9.21097565,  9.21814728,
        9.21858311,  9.22161484,  9.19798279,  9.21609306,  9.22385216,
        9.21418095,  9.21105194,  9.22471142,  9.22412109,  9.2043066 ,
        9.19688988,  9.22003841,  9.20054531,  9.19444084,  9.21278763,
        9.203125  ,  9.2108345 ,  9.20475674,  9.21102905,  9.21219063,
        9.22065926,  9.20991802,  9.21213913,  9.22084522,  9.21117687,
        9.19951057,  9.22727203,  9.21122932,  9.22421837,  9.19551754,
        9.20456409,  9.22784996,  9.22461033,  9.223773  ,  9.21427441,
        9.20698929,  9.22384644,  9.21223545,  9.2200613 ,  9.22363091,
        9.21439934,  9.21442986,  9.21107101,  9.19480705,  9.22471523,
        9.21092129,  9.21970463,  9.19644737,  9.20577908,  9.21106148,
        9.21226597,  9.22467995,  9.2135582 ,  9.22753716,  9.19999886,
        9.22570801,  9.22474098,  9.21431446,  9.21435356,  9.21269894,
        9.19632244,  9.20470047,  9.21263695,  9.22559929,  9.2169199 ,
        9.20761585,  9.20283222,  9.21426773,  9.21609497,  9.19695473,
        9.19874096,  9.21025658,  9.21992874,  9.22215557,  9.22673512,
        9.22612476,  9.21404362,  9.21098042,  9.21979046,  9.20807838,
        9.22177696,  9.22042465,  9.1964159 ,  9.20781136,  9.20277309,
        9.21411991,  9.218297  ,  9.19809532,  9.21265125,  9.19514179,
        9.19884396,  9.22204494,  9.22030354,  9.20478153,  9.2264328 ,
        9.21141148,  9.22245216,  9.21376133,  9.21111393,  9.2200985 ,
        9.20054626,  9.21999168,  9.19898224,  9.2270956 ,  9.21626663,
        9.20011616,  9.1944828 ,  9.22192764,  9.20949936,  9.21084404,
        9.21433735,  9.19639492,  9.21091652,  9.20312214,  9.22069359,
        9.221241  ,  9.20541954,  9.21430492,  9.21108437,  9.22204018,
        9.22081947,  9.20650482,  9.19887543,  9.2237339 ,  9.22413635,
        9.20449257,  9.20556164,  9.22728348,  9.21774292,  9.21104622,
        9.19664001,  9.2222538 ,  9.19368935,  9.20439434,  9.20623302,
        9.20261288,  9.21841717,  9.19601059,  9.20782375,  9.20269871,
        9.19635582,  9.19632244,  9.22659492,  9.2186718 ,  9.22645855,
        9.22407818,  9.21437168,  9.21092033,  9.21113586,  9.20318413,
        9.21435833,  9.21107292,  9.21833801,  9.21167278,  9.21221542,
        9.20137978,  9.21408653,  9.19821453,  9.21104813,  9.21407032,
        9.21021652,  9.22406864,  9.21269512,  9.21934795,  9.22368813,
        9.22669315,  9.21704483,  9.22618675,  9.2275095 ,  9.21618843,
        9.20898914,  9.21116638,  9.22081661,  9.21410179,  9.21071148,
        9.20818424,  9.20185184,  9.20290661,  9.21076012,  9.20226765,
        9.214468  ,  9.2107296 ,  9.22440243,  9.21044731,  9.20544052,
        9.21266842,  9.20977116,  9.20062542,  9.22651863,  9.20812702,
        9.21624279,  9.2078104 ,  9.21463776,  9.21169662,  9.20409775,
        9.22041225,  9.19636154,  9.20774555,  9.20285416,  9.21296597,
        9.21858406,  9.19896698,  9.22645664,  9.19713306,  9.20236874,
        9.20859623,  9.22179317,  9.202878  ,  9.22032452,  9.19367123,
        9.21871471,  9.21264648,  9.2163496 ,  9.20523167,  9.19640636,
        9.21868706,  9.21865559,  9.22040939,  9.19633961,  9.20763588,
        9.19808483,  9.22059727,  9.21104908,  9.1958704 ,  9.20275116,
        9.2252121 ,  9.20281124,  9.2112875 ,  9.21419048,  9.21079731,
        9.22663879,  9.22022247,  9.21682644,  9.22265339,  9.21986485,
        9.21680641,  9.19416714,  9.21992683,  9.21712494,  9.19954491,
        9.20378113,  9.21976566,  9.22378731,  9.21103573,  9.19931507,
        9.22149658,  9.2206583 ,  9.22398281,  9.20695019,  9.21020985,
        9.22399235,  9.1946907 ,  9.2235899 ,  9.22050095,  9.20793438,
        9.21102428,  9.21203327,  9.20169258,  9.19778728,  9.19744396,
        9.19776058,  9.20174026,  9.21416378,  9.2191658 ,  9.19989872], dtype=float32)

In [43]:
cost = tf.reduce_sum(loss) / batch_size

session.run(tf.global_variables_initializer())
session.run(cost, feed_dict)


Out[43]:
184.20708
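
As a sanity check, this is exactly what we should expect from an untrained model: the per-word loss is about ln(10000) ≈ 9.21, i.e. the network is still guessing roughly uniformly over the 10,000-word vocabulary, which corresponds to a per-word perplexity of about 10,000:

# cost = sum(loss) / batch_size, so dividing by num_steps gives the average per-word loss
per_word_loss = 184.20708 / num_steps   # ~9.21, which is ~= np.log(vocab_size)
print(np.exp(per_word_loss))            # per-word perplexity of roughly 10,000 before any training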

Now, let's store the new state as the final state:


In [44]:
# 
final_state = new_state

Training

To do gradient clipping in TensorFlow we have to take the following steps (a combined sketch follows this list; after that, we go through each step in detail):

  1. Define the optimizer.
  2. Extract variables that are trainable.
  3. Calculate the gradients based on the loss function.
  4. Apply the optimizer to the variables / gradients tuple.
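
Before walking through each step in the cells below, here is a minimal sketch of how these four pieces typically fit together in TensorFlow 1.x (using the cost defined above and the max_grad_norm hyperparameter; the following cells build the same pieces one at a time):

# 1. define the optimizer, with a learning-rate variable we can update during training
lr = tf.Variable(0.0, trainable=False)
optimizer = tf.train.GradientDescentOptimizer(lr)
# 2. collect the trainable variables
tvars = tf.trainable_variables()
# 3. compute the gradients of the cost and clip them by their global norm
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm)
# 4. apply the clipped gradients to the variables
train_op = optimizer.apply_gradients(zip(grads, tvars))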

1. Define Optimizer

GradientDescentOptimizer constructs a new gradient descent optimizer. Later, we use the constructed optimizer to compute the gradients of the loss and apply them to the variables.


In [45]:
# Create a variable for the learning rate
lr = tf.Variable(0.0, trainable=False)
# Create the gradient descent optimizer with our learning rate
optimizer = tf.train.GradientDescentOptimizer(lr)

2. Trainable Variables

When defining a variable, if you pass trainable=True, the Variable() constructor automatically adds it to the graph collection GraphKeys.TRAINABLE_VARIABLES. Now, using tf.trainable_variables() you can get all variables created with trainable=True.


In [46]:
# Get all TensorFlow variables marked as "trainable" (i.e. all of them except lr, which we just created)
tvars = tf.trainable_variables()
tvars


Out[46]:
[<tf.Variable 'embedding:0' shape=(10000, 200) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'softmax_w:0' shape=(200, 10000) dtype=float32_ref>,
 <tf.Variable 'softmax_b:0' shape=(10000,) dtype=float32_ref>]

We can find the name and scope of each variable. For the rest of this walkthrough, we keep only the last few of them:


In [47]:
tvars=tvars[3:]

In [48]:
[v.name for v in tvars]


Out[48]:
['rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights:0',
 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases:0',
 'softmax_w:0',
 'softmax_b:0']

3. Calculate the gradients based on the loss function


In [49]:
cost


Out[49]:
<tf.Tensor 'truediv:0' shape=() dtype=float32>

In [50]:
tvars


Out[50]:
[<tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'softmax_w:0' shape=(200, 10000) dtype=float32_ref>,
 <tf.Variable 'softmax_b:0' shape=(10000,) dtype=float32_ref>]

Gradient:

The gradient of a function is its rate of change. It is a vector (a direction to move) that points in the direction of greatest increase of the function, and it is calculated by taking derivatives.

First, let's recall how gradients work using a toy example: $$ z=2x^2+3xy $$


In [51]:
var_x = tf.placeholder(tf.float32)
var_y = tf.placeholder(tf.float32) 
func_test = 2.0*var_x*var_x + 3.0*var_x*var_y
session.run(tf.global_variables_initializer())
feed={var_x:1.0,var_y:2.0}
session.run(func_test, feed)


Out[51]:
8.0

The tf.gradients() function allows you to compute the symbolic gradient of one tensor with respect to one or more other tensors -- including variables. tf.gradients(func, xs) constructs symbolic partial derivatives of the sum of func w.r.t. each x in xs.

Now, let's look at the derivative w.r.t. var_x: $$ \frac{\partial}{\partial x}\left(2x^2+3xy\right)=4x+3y $$


In [52]:
var_grad = tf.gradients(func_test, [var_x])
session.run(var_grad,feed)


Out[52]:
[10.0]

The derivative w.r.t. var_y: $$ \frac{\partial}{\partial y}\left(2x^2+3xy\right)=3x $$


In [53]:
var_grad = tf.gradients(func_test, [var_y])
session.run(var_grad,feed)


Out[53]:
[3.0]

Now, we can look at gradients w.r.t all variables:


In [54]:
tf.gradients(cost, tvars)


Out[54]:
[<tf.Tensor 'gradients_2/rnn/while/multi_rnn_cell/cell_1/basic_lstm_cell/basic_lstm_cell_1/MatMul/Enter_grad/b_acc_3:0' shape=(400, 800) dtype=float32>,
 <tf.Tensor 'gradients_2/rnn/while/multi_rnn_cell/cell_1/basic_lstm_cell/basic_lstm_cell_1/BiasAdd/Enter_grad/b_acc_3:0' shape=(800,) dtype=float32>,
 <tf.Tensor 'gradients_2/MatMul_grad/MatMul_1:0' shape=(200, 10000) dtype=float32>,
 <tf.Tensor 'gradients_2/add_grad/Reshape_1:0' shape=(10000,) dtype=float32>]

In [55]:
grad_t_list = tf.gradients(cost, tvars)
#sess.run(grad_t_list,feed_dict)

Now we have a list of tensors, t_list. We can use it to find the clipped tensors. clip_by_global_norm clips the values of multiple tensors by the ratio of the sum of their norms.

clip_by_global_norm takes t_list as input and returns 2 things:

  • a list of clipped tensors, the so-called list_clipped
  • the global norm (global_norm) of all tensors in t_list

In [56]:
max_grad_norm


Out[56]:
5

In [57]:
# Define the gradient clipping threshold
grads, _ = tf.clip_by_global_norm(grad_t_list, max_grad_norm)
grads


Out[57]:
[<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_0:0' shape=(400, 800) dtype=float32>,
 <tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_1:0' shape=(800,) dtype=float32>,
 <tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_2:0' shape=(200, 10000) dtype=float32>,
 <tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_3:0' shape=(10000,) dtype=float32>]

In [58]:
session.run(grads,feed_dict)


Out[58]:
[array([[  1.07558318e-08,   9.12204889e-09,   1.12856116e-08, ...,
          -7.70966224e-09,   8.92476137e-09,  -8.08427369e-09],
        [ -6.10435347e-09,   9.43148670e-09,  -1.03863957e-08, ...,
          -1.03984368e-08,   1.23483987e-08,  -8.44202575e-09],
        [ -5.72466208e-09,  -1.07929443e-09,   4.95039343e-09, ...,
          -6.85579460e-09,   8.23519297e-09,   7.05244751e-10],
        ..., 
        [  2.30261255e-09,  -1.12141718e-09,  -1.58951663e-09, ...,
           2.88229907e-09,   6.81134316e-12,   9.39832101e-10],
        [ -1.38669620e-09,  -1.34926359e-09,   1.63711478e-09, ...,
           2.30973085e-09,  -9.56553237e-09,   1.31781841e-09],
        [  1.11014176e-09,  -6.27862784e-09,   2.11078355e-09, ...,
           2.88673974e-09,  -2.77861556e-09,  -5.81890203e-10]], dtype=float32),
 array([ -9.82914344e-07,  -2.70246005e-06,   1.43185162e-06,
         -9.93417899e-08,   3.67005305e-06,   7.76386787e-06,
          3.61585080e-07,   6.92272351e-07,  -9.32212151e-06,
          5.18600609e-06,  -4.02639125e-06,   5.18167371e-06,
          3.17013087e-06,   7.17835701e-06,  -3.41107079e-06,
         -1.95995517e-06,  -5.75463810e-06,  -8.08337063e-06,
         -9.58691817e-06,  -8.18853465e-08,  -6.32245610e-06,
          4.60057026e-06,   1.59306751e-06,   4.29707779e-06,
          1.82323299e-06,   2.90430148e-06,   3.32796958e-06,
         -1.01870444e-06,   2.77856930e-06,   3.59660930e-06,
         -2.36861524e-06,   6.73505883e-06,  -4.50051584e-06,
         -4.52955646e-07,   5.53370228e-06,   3.60019999e-06,
          3.83563292e-06,  -1.93417145e-06,  -6.10211464e-06,
          6.91721254e-07,  -1.42459089e-06,  -5.49618221e-07,
          6.64596087e-07,   1.41093949e-06,   3.41143095e-06,
         -1.12702628e-06,   1.42534236e-06,   6.63407718e-07,
         -1.07590552e-06,  -8.79912022e-07,  -6.29792794e-07,
          5.92162678e-06,  -2.42609622e-06,   7.91492039e-06,
         -1.58074170e-06,   3.03377897e-06,  -3.23260707e-07,
          1.37384279e-06,  -1.43062130e-06,   7.24143410e-06,
         -4.07794573e-07,  -3.59753017e-07,  -5.98142378e-06,
         -6.92211074e-07,   4.07735706e-06,   5.02325884e-06,
          1.22761389e-06,  -6.51060282e-06,   1.29732393e-06,
          3.51321091e-06,   5.27950397e-06,  -3.25940704e-07,
          1.78458777e-08,   3.83401448e-06,   1.61028356e-06,
         -1.90295020e-06,  -1.35942037e-06,  -4.88180785e-06,
         -4.69322868e-06,   2.78170933e-06,  -3.12692850e-06,
          7.84308895e-07,  -3.69510076e-07,  -1.92276389e-06,
          3.04595073e-06,  -5.18197453e-07,   5.65314849e-07,
         -9.00348459e-06,  -9.66881402e-08,   3.18885691e-06,
          6.76951868e-06,  -1.72842795e-06,  -2.24759447e-06,
         -3.34266633e-06,   5.71129249e-07,  -3.99725786e-06,
         -2.14304873e-06,   1.37092602e-05,  -4.30679574e-06,
          3.01267892e-06,   1.45350612e-06,  -1.86071588e-07,
         -3.03869115e-06,   8.53636038e-07,   4.74442140e-06,
         -1.44282603e-06,  -4.53262737e-06,   1.23572011e-06,
         -1.66912184e-06,   3.73479156e-06,  -5.29677209e-08,
         -1.64481287e-06,   3.32753416e-07,   7.03130763e-06,
          1.62235176e-06,   3.17398803e-06,  -6.65567632e-06,
         -3.23337957e-07,   3.01984664e-06,  -8.66141534e-07,
         -2.01093121e-06,  -2.70194437e-06,   1.02681497e-05,
         -3.63040999e-06,   8.02593377e-07,  -3.25779047e-06,
         -1.82838971e-06,  -2.99340491e-06,   1.20133200e-06,
          6.03289254e-06,   2.42753777e-06,  -1.77725695e-07,
          9.32380317e-07,  -4.86792487e-07,  -1.93151934e-07,
          4.64627647e-07,  -4.40223857e-09,  -1.23695406e-06,
         -9.02738338e-06,   6.45926775e-07,  -2.91312131e-06,
          3.35279628e-06,   5.40411656e-06,   4.60168258e-06,
          3.66765516e-06,   2.35646712e-06,  -8.42410884e-07,
          2.96056351e-06,   1.87226783e-06,   6.46565240e-06,
          4.95423365e-06,   1.27702788e-06,  -2.95620862e-06,
          ... (output truncated -- the remaining gradient and weight arrays are omitted for brevity) ...], dtype=float32)]

4. Apply the optimizer to the variables / gradients tuple.


In [59]:
# Create the training TensorFlow Operation through our optimizer
train_op = optimizer.apply_gradients(zip(grads, tvars))

In [60]:
# Initialize all variables and run one training step with the feed_dict we built earlier
session.run(tf.global_variables_initializer())
session.run(train_op, feed_dict)

We learned how the model is built step by step. Now, let's create a class that represents our model. This class needs a few things:

  • We have to create the model in accordance with our defined hyperparameters
  • We have to create the placeholders for our input data and expected outputs (the real data)
  • We have to create the LSTM cell structure and connect it with our RNN structure
  • We have to create the word embeddings and point them to the input data
  • We have to create the input structure for our RNN
  • We have to instantiate our RNN model and retrieve the variable in which we should expect our outputs to appear
  • We need to create a logistic structure to return the probability of our words
  • We need to create the loss and cost functions for our optimizer to work, and then create the optimizer
  • And finally, we need to create a training operation that can be run to actually train our model

In [61]:
class PTBModel(object):

    def __init__(self, is_training):
        ######################################
        # Setting parameters for ease of use #
        ######################################
        self.batch_size = batch_size
        self.num_steps = num_steps
        size = hidden_size
        self.vocab_size = vocab_size
        
        ###############################################################################
        # Creating placeholders for our input data and expected outputs (target data) #
        ###############################################################################
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps]) #[30x20]
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps]) #[30x20]

        ##########################################################################
        # Creating the LSTM cell structure and connect it with the RNN structure #
        ##########################################################################
        # Create the LSTM unit. 
        # This creates only the structure for the LSTM and has to be associated with a RNN unit still.
        # The size argument of BasicLSTMCell (size=200) is the size of the hidden layer, that is, the number of hidden units of the LSTM cell.
        # No bias is added to the Forget Gate (forget_bias=0.0).
        # The LSTM cell processes one word at a time and computes the probabilities of the possible continuations of the sentence.
        lstm_cells = []
        reuse = tf.get_variable_scope().reuse
        for _ in range(num_layers):
            cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0, reuse=reuse)
            if is_training and keep_prob < 1:
                # Unless you changed keep_prob, this won't actually execute -- this is a dropout wrapper for our LSTM unit
                # Dropout regularizes the LSTM output, but it is not strictly needed for this example
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
            lstm_cells.append(cell)
           
        # MultiRNNCell takes the list of LSTM cells and stacks them into a single RNN cell
        # composed sequentially of multiple simpler cells (one per layer).
        stacked_lstm = tf.contrib.rnn.MultiRNNCell(lstm_cells)

        # Define the initial state, i.e., the model state for the very first data point
        # It initializes the state of the LSTM memory. The memory state of the network is initialized with a vector of zeros and gets updated after reading each word.
        self._initial_state = stacked_lstm.zero_state(batch_size, tf.float32)

        ####################################################################
        # Creating the word embeddings and pointing them to the input data #
        ####################################################################
        with tf.device("/cpu:0"):
            # Create the embeddings for our input data. Size is hidden size.
            # Uses default variable initializer
            embedding = tf.get_variable("embedding", [vocab_size, size])  #[10000x200]
            # Define where to get the data for our embeddings from
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        # Unless you changed keep_prob, this won't actually execute -- this is a dropout addition for our inputs
        # Dropout regularizes the input embeddings, but it is not strictly needed for this example
        if is_training and keep_prob < 1:
            inputs = tf.nn.dropout(inputs, keep_prob)

        ############################################
        # Creating the input structure for our RNN #
        ############################################
        # Input structure is [30x20x200]
        # Considering each word is represented by a 200-dimensional vector and each batch holds 30 sentences of 20 steps,
        # the embedded input fed to the RNN is a tensor of shape [30x20x200]
        # (the legacy static-RNN API would instead split this into 20 tensors of shape [30x200], one per time step):
        #inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, num_steps, inputs)]
        # The input structure is fed from the embeddings, which are filled in by the input data
        # Feeding a batch of b sentences to a RNN:
        # In step 1,  first word of each of the b sentences (in a batch) is input in parallel.  
        # In step 2,  second word of each of the b sentences is input in parallel. 
        # The parallelism is only for efficiency.  
        # Each sentence in a batch is handled in parallel, but the network sees one word of a sentence at a time and does the computations accordingly. 
        # All the computations involving the words of all sentences in a batch at a given time step are done in parallel. 

        ####################################################################################################
        # Instantiating our RNN model and retrieving the structure for returning the outputs and the state #
        ####################################################################################################
        outputs, state = tf.nn.dynamic_rnn(stacked_lstm, inputs, initial_state=self._initial_state)

        #########################################################################
        # Creating a logistic unit to return the probability of the output word #
        #########################################################################
        output = tf.reshape(outputs, [-1, size])
        softmax_w = tf.get_variable("softmax_w", [size, vocab_size]) #[200x10000]
        softmax_b = tf.get_variable("softmax_b", [vocab_size]) #[1x10000]
        logits = tf.matmul(output, softmax_w) + softmax_b

        #########################################################################
        # Defining the loss and cost functions for the model's learning to work #
        #########################################################################
        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [tf.reshape(self._targets, [-1])],
                                                      [tf.ones([batch_size * num_steps])])
        self._cost = cost = tf.reduce_sum(loss) / batch_size

        # Store the final state
        self._final_state = state

        #Everything after this point is relevant only for training
        if not is_training:
            return

        #################################################
        # Creating the Training Operation for our Model #
        #################################################
        # Create a variable for the learning rate
        self._lr = tf.Variable(0.0, trainable=False)
        # Get all TensorFlow variables marked as "trainable" (i.e. all of them except _lr, which we just created)
        tvars = tf.trainable_variables()
        # Define the gradient clipping threshold
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm)
        # Create the gradient descent optimizer with our learning rate
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        # Create the training TensorFlow Operation through our optimizer
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    # Helper functions for our LSTM RNN class

    # Assign the learning rate for this model
    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))

    # Returns the input data for this model at a point in time
    @property
    def input_data(self):
        return self._input_data

    # Returns the targets for this model at a point in time
    @property
    def targets(self):
        return self._targets

    # Returns the initial state for this model
    @property
    def initial_state(self):
        return self._initial_state

    # Returns the defined Cost
    @property
    def cost(self):
        return self._cost

    # Returns the final state for this model
    @property
    def final_state(self):
        return self._final_state

    # Returns the current learning rate for this model
    @property
    def lr(self):
        return self._lr

    # Returns the training operation defined for this model
    @property
    def train_op(self):
        return self._train_op

With that, the actual structure of our Recurrent Neural Network with Long Short-Term Memory is finished. What remains is to create the methods that run the model through time -- that is, the run_epoch method to be run at each epoch and a main script that ties all of this together.

What our run_epoch method should do is take our input data and feed it to the relevant operations. This will return at the very least the current result for the cost function.
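Before looking at the run_epoch code itself, it may help to picture what each (x, y) pair produced by reader.ptb_iterator looks like: both are [batch_size x num_steps] matrices of word ids, and y is simply x shifted one word ahead, so that at every time step the target is the next word in the text. Below is a minimal toy sketch of that relationship (made-up word ids and tiny sizes, not the actual reader code):

import numpy as np

# Toy word-id stream and tiny sizes, purely to illustrate the shapes
raw_ids = np.arange(1, 13)                   # pretend these are word ids from the corpus
batch_size, num_steps = 2, 3

# Cut the stream into batch_size parallel sequences, as the reader does
batch_len = len(raw_ids) // batch_size
data = raw_ids[:batch_size * batch_len].reshape(batch_size, batch_len)

# The first (x, y) pair of the epoch: rows are sentences in the batch, columns are time steps
x = data[:, 0:num_steps]          # [[1 2 3], [7 8 9]]
y = data[:, 1:num_steps + 1]      # [[2 3 4], [8 9 10]]  -- the targets, shifted one word ahead
print(x.shape, y.shape)           # (2, 3) (2, 3)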


In [62]:
##########################################################################################################################
# run_epoch takes as parameters the current session, the model instance, the data to be fed, and the operation to be run #
##########################################################################################################################
def run_epoch(session, m, data, eval_op, verbose=False):

    #Define the epoch size based on the length of the data, batch size and the number of steps
    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
    start_time = time.time()
    costs = 0.0
    iters = 0
    # Evaluate the initial state tensors so they can be fed back in at each step
    state = session.run(m.initial_state)
    
    #For each step and data point
    for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size, m.num_steps)):
 
        #Evaluate and return cost, state by running cost, final_state and the function passed as parameter
        cost, state, _ = session.run([m.cost, m.final_state, eval_op],
                                     {m.input_data: x,
                                      m.targets: y,
                                      m.initial_state: state})
        
        #Add returned cost to costs (which keeps track of the total costs for this epoch)
        costs += cost
        
        #Add number of steps to iteration counter
        iters += m.num_steps

        if verbose and (step % 10) == 0:
            print("({:.2%})    Perplexity={:.3f}    Speed={:.0f} wps".format(
                    step * 1.0 / epoch_size, 
                    np.exp(costs / iters), 
                    iters * m.batch_size / (time.time() - start_time))
                 )

    # Returns the Perplexity rating for us to keep track of how the model is evolving
    return np.exp(costs / iters)

Now, we create the main method to tie everything together. The code here reads the data from the data directory using the reader helper module, trains the model, and evaluates it on both the validation and the testing subsets of the data.


In [63]:
# Reads the data and separates it into training data, validation data and testing data
raw_data = reader.ptb_raw_data(data_dir)
train_data, valid_data, test_data, _ = raw_data

In [64]:
#Initializes the Execution Graph and the Session
with tf.Graph().as_default(), tf.Session() as session:
    initializer = tf.random_uniform_initializer(-init_scale,init_scale)
    
    # Instantiates the model for training
    # tf.variable_scope adds a prefix to the variables created with tf.get_variable
    with tf.variable_scope("model", reuse=None, initializer=initializer):
        m = PTBModel(is_training=True)
        
    # Reuses the trained parameters for the validation and testing models
    # They are different instances but use the same variables for weights and biases, 
    # they just don't change when data is input
    with tf.variable_scope("model", reuse=True, initializer=initializer):
        mvalid = PTBModel(is_training=False)
        mtest = PTBModel(is_training=False)

    #Initialize all variables
    tf.global_variables_initializer().run()
    
    # Set initial learning rate
    m.assign_lr(session=session, lr_value=learning_rate)
    
    for i in range(max_max_epoch):
        print("Epoch %d : Learning rate: %.3f" % (i + 1, session.run(m.lr)))
               
        # Run the loop for this epoch in the training model
        train_perplexity = run_epoch(session, m, train_data, m.train_op,
                                   verbose=True)
        print("Epoch %d : Train Perplexity: %.3f" % (i + 1, train_perplexity))
        
        # Run the loop for this epoch in the validation model
        valid_perplexity = run_epoch(session, mvalid, valid_data, tf.no_op())
        print("Epoch %d : Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
        
        # Define the decay for the next epoch 
        lr_decay = decay * ((max_max_epoch - i) / max_max_epoch)
        
        # Set the decayed learning rate as the learning rate for the next epoch
        m.assign_lr(session, learning_rate * lr_decay)
    
    # Run the loop in the testing model to see how effective our training was
    test_perplexity = run_epoch(session, mtest, test_data, tf.no_op())
    
    print("Test Perplexity: %.3f" % test_perplexity)


Epoch 1 : Learning rate: 1.000
(0.00%)    Perplexity=10008.760    Speed=2745 wps
(0.65%)    Perplexity=8099.698    Speed=4014 wps
(1.29%)    Perplexity=3874.326    Speed=4117 wps
(1.94%)    Perplexity=2700.296    Speed=4155 wps
(2.58%)    Perplexity=2236.035    Speed=4176 wps
(3.23%)    Perplexity=1956.309    Speed=4186 wps
(3.87%)    Perplexity=1758.261    Speed=4197 wps
(4.52%)    Perplexity=1612.279    Speed=4201 wps
(5.16%)    Perplexity=1504.809    Speed=4208 wps
(5.81%)    Perplexity=1431.562    Speed=4211 wps
(6.46%)    Perplexity=1358.713    Speed=4214 wps
(7.10%)    Perplexity=1299.117    Speed=4216 wps
(7.75%)    Perplexity=1254.090    Speed=4217 wps
(8.39%)    Perplexity=1206.576    Speed=4219 wps
(9.04%)    Perplexity=1168.162    Speed=4220 wps
(9.68%)    Perplexity=1136.567    Speed=4221 wps
(10.33%)    Perplexity=1097.991    Speed=4198 wps
(10.97%)    Perplexity=1075.078    Speed=4180 wps
(11.62%)    Perplexity=1048.018    Speed=4176 wps
(12.27%)    Perplexity=1022.558    Speed=4129 wps
(12.91%)    Perplexity=999.630    Speed=4081 wps
(13.56%)    Perplexity=975.413    Speed=4017 wps
(14.20%)    Perplexity=950.739    Speed=4015 wps
(14.85%)    Perplexity=931.122    Speed=4019 wps
(15.49%)    Perplexity=910.874    Speed=4013 wps
(16.14%)    Perplexity=894.337    Speed=4000 wps
(16.79%)    Perplexity=875.736    Speed=3926 wps
(17.43%)    Perplexity=858.051    Speed=3845 wps
(18.08%)    Perplexity=843.910    Speed=3851 wps
(18.72%)    Perplexity=830.698    Speed=3859 wps
(19.37%)    Perplexity=817.435    Speed=3870 wps
(20.01%)    Perplexity=804.527    Speed=3878 wps
(20.66%)    Perplexity=789.208    Speed=3886 wps
(21.30%)    Perplexity=776.511    Speed=3894 wps
(21.95%)    Perplexity=764.009    Speed=3888 wps
(22.60%)    Perplexity=749.769    Speed=3838 wps
(23.24%)    Perplexity=738.858    Speed=3806 wps
(23.89%)    Perplexity=730.130    Speed=3806 wps
(24.53%)    Perplexity=720.693    Speed=3812 wps
(25.18%)    Perplexity=709.900    Speed=3822 wps
(25.82%)    Perplexity=698.234    Speed=3830 wps
(26.47%)    Perplexity=688.941    Speed=3833 wps
(27.11%)    Perplexity=678.943    Speed=3791 wps
(27.76%)    Perplexity=669.945    Speed=3765 wps
(28.41%)    Perplexity=659.722    Speed=3771 wps
(29.05%)    Perplexity=652.814    Speed=3771 wps
(29.70%)    Perplexity=644.648    Speed=3779 wps
(30.34%)    Perplexity=636.400    Speed=3788 wps
(30.99%)    Perplexity=629.712    Speed=3791 wps
(31.63%)    Perplexity=621.451    Speed=3751 wps
(32.28%)    Perplexity=613.263    Speed=3728 wps
(32.92%)    Perplexity=606.627    Speed=3734 wps
(33.57%)    Perplexity=598.318    Speed=3741 wps
(34.22%)    Perplexity=591.316    Speed=3746 wps
(34.86%)    Perplexity=584.528    Speed=3749 wps
(35.51%)    Perplexity=576.849    Speed=3753 wps
(36.15%)    Perplexity=567.809    Speed=3724 wps
(36.80%)    Perplexity=560.043    Speed=3701 wps
(37.44%)    Perplexity=554.562    Speed=3704 wps
(38.09%)    Perplexity=548.875    Speed=3707 wps
(38.73%)    Perplexity=543.302    Speed=3711 wps
(39.38%)    Perplexity=538.614    Speed=3715 wps
(40.03%)    Perplexity=534.055    Speed=3718 wps
(40.67%)    Perplexity=527.986    Speed=3721 wps
(41.32%)    Perplexity=522.890    Speed=3689 wps
(41.96%)    Perplexity=517.365    Speed=3668 wps
(42.61%)    Perplexity=513.109    Speed=3670 wps
(43.25%)    Perplexity=507.260    Speed=3668 wps
(43.90%)    Perplexity=502.878    Speed=3657 wps
(44.54%)    Perplexity=500.516    Speed=3645 wps
(45.19%)    Perplexity=497.019    Speed=3629 wps
(45.84%)    Perplexity=493.596    Speed=3579 wps
(46.48%)    Perplexity=489.869    Speed=3548 wps
(47.13%)    Perplexity=486.341    Speed=3530 wps
(47.77%)    Perplexity=481.968    Speed=3519 wps
(48.42%)    Perplexity=477.831    Speed=3493 wps
(49.06%)    Perplexity=473.399    Speed=3468 wps
(49.71%)    Perplexity=469.681    Speed=3461 wps
(50.36%)    Perplexity=466.603    Speed=3455 wps
(51.00%)    Perplexity=464.414    Speed=3449 wps
(51.65%)    Perplexity=461.455    Speed=3436 wps
(52.29%)    Perplexity=458.084    Speed=3406 wps
(52.94%)    Perplexity=455.205    Speed=3400 wps
(53.58%)    Perplexity=451.674    Speed=3396 wps
(54.23%)    Perplexity=448.082    Speed=3391 wps
(54.87%)    Perplexity=444.897    Speed=3385 wps
(55.52%)    Perplexity=442.302    Speed=3379 wps
(56.17%)    Perplexity=439.709    Speed=3352 wps
(56.81%)    Perplexity=437.226    Speed=3338 wps
(57.46%)    Perplexity=434.178    Speed=3335 wps
(58.10%)    Perplexity=430.994    Speed=3331 wps
(58.75%)    Perplexity=427.388    Speed=3327 wps
(59.39%)    Perplexity=423.634    Speed=3322 wps
(60.04%)    Perplexity=420.775    Speed=3307 wps
(60.68%)    Perplexity=418.026    Speed=3291 wps
(61.33%)    Perplexity=414.823    Speed=3290 wps
(61.98%)    Perplexity=411.586    Speed=3287 wps
(62.62%)    Perplexity=408.680    Speed=3282 wps
(63.27%)    Perplexity=406.041    Speed=3277 wps
(63.91%)    Perplexity=403.592    Speed=3268 wps
(64.56%)    Perplexity=401.071    Speed=3244 wps
(65.20%)    Perplexity=398.725    Speed=3231 wps
(65.85%)    Perplexity=397.402    Speed=3223 wps
(66.49%)    Perplexity=395.679    Speed=3216 wps
(67.14%)    Perplexity=393.856    Speed=3208 wps
(67.79%)    Perplexity=392.252    Speed=3192 wps
(68.43%)    Perplexity=390.085    Speed=3179 wps
(69.08%)    Perplexity=388.317    Speed=3176 wps
(69.72%)    Perplexity=386.673    Speed=3164 wps
(70.37%)    Perplexity=384.804    Speed=3150 wps
(71.01%)    Perplexity=382.838    Speed=3149 wps
(71.66%)    Perplexity=381.279    Speed=3146 wps
(72.30%)    Perplexity=379.147    Speed=3144 wps
(72.95%)    Perplexity=377.192    Speed=3140 wps
(73.60%)    Perplexity=375.482    Speed=3124 wps
(74.24%)    Perplexity=373.620    Speed=3122 wps
(74.89%)    Perplexity=371.637    Speed=3119 wps
(75.53%)    Perplexity=370.085    Speed=3118 wps
(76.18%)    Perplexity=368.206    Speed=3117 wps
(76.82%)    Perplexity=366.853    Speed=3115 wps
(77.47%)    Perplexity=365.430    Speed=3108 wps
(78.11%)    Perplexity=363.921    Speed=3093 wps
(78.76%)    Perplexity=362.338    Speed=3091 wps
(79.41%)    Perplexity=361.007    Speed=3089 wps
(80.05%)    Perplexity=359.436    Speed=3086 wps
(80.70%)    Perplexity=358.220    Speed=3071 wps
(81.34%)    Perplexity=356.745    Speed=3061 wps
(81.99%)    Perplexity=354.658    Speed=3061 wps
(82.63%)    Perplexity=353.116    Speed=3062 wps
(83.28%)    Perplexity=351.679    Speed=3062 wps
(83.93%)    Perplexity=349.639    Speed=3061 wps
(84.57%)    Perplexity=347.860    Speed=3060 wps
(85.22%)    Perplexity=345.891    Speed=3056 wps
(85.86%)    Perplexity=343.976    Speed=3043 wps
(86.51%)    Perplexity=342.140    Speed=3033 wps
(87.15%)    Perplexity=340.258    Speed=3032 wps
(87.80%)    Perplexity=338.919    Speed=3028 wps
(88.44%)    Perplexity=337.459    Speed=3016 wps
(89.09%)    Perplexity=336.455    Speed=3009 wps
(89.74%)    Perplexity=335.655    Speed=3008 wps
(90.38%)    Perplexity=334.353    Speed=3010 wps
(91.03%)    Perplexity=333.083    Speed=3010 wps
(91.67%)    Perplexity=331.208    Speed=3005 wps
(92.32%)    Perplexity=329.809    Speed=2995 wps
(92.96%)    Perplexity=328.492    Speed=2996 wps
(93.61%)    Perplexity=326.863    Speed=2995 wps
(94.25%)    Perplexity=325.508    Speed=2996 wps
(94.90%)    Perplexity=324.050    Speed=2994 wps
(95.55%)    Perplexity=322.746    Speed=2986 wps
(96.19%)    Perplexity=321.519    Speed=2977 wps
(96.84%)    Perplexity=320.622    Speed=2977 wps
(97.48%)    Perplexity=319.675    Speed=2976 wps
(98.13%)    Perplexity=318.654    Speed=2968 wps
(98.77%)    Perplexity=317.592    Speed=2963 wps
(99.42%)    Perplexity=316.679    Speed=2962 wps
Epoch 1 : Train Perplexity: 315.925
Epoch 1 : Valid Perplexity: 192.533
Epoch 2 : Learning rate: 0.500
(0.00%)    Perplexity=260.972    Speed=2275 wps
(0.65%)    Perplexity=206.411    Speed=2299 wps
(1.29%)    Perplexity=190.974    Speed=2257 wps
(1.94%)    Perplexity=180.341    Speed=2413 wps
(2.58%)    Perplexity=178.556    Speed=2405 wps
(3.23%)    Perplexity=179.633    Speed=2321 wps
(3.87%)    Perplexity=178.427    Speed=2346 wps
(4.52%)    Perplexity=174.278    Speed=2393 wps
(5.16%)    Perplexity=174.229    Speed=2465 wps
(5.81%)    Perplexity=176.668    Speed=2518 wps
(6.46%)    Perplexity=175.931    Speed=2490 wps
(7.10%)    Perplexity=177.729    Speed=2434 wps
(7.75%)    Perplexity=179.460    Speed=2434 wps
(8.39%)    Perplexity=179.178    Speed=2395 wps
(9.04%)    Perplexity=179.378    Speed=2361 wps
(9.68%)    Perplexity=179.966    Speed=2372 wps
(10.33%)    Perplexity=179.652    Speed=2334 wps
(10.97%)    Perplexity=180.878    Speed=2308 wps
(11.62%)    Perplexity=180.911    Speed=2305 wps
(12.27%)    Perplexity=180.077    Speed=2264 wps
(12.91%)    Perplexity=178.846    Speed=2268 wps
(13.56%)    Perplexity=177.876    Speed=2275 wps
(14.20%)    Perplexity=176.449    Speed=2292 wps
(14.85%)    Perplexity=176.358    Speed=2274 wps
(15.49%)    Perplexity=175.430    Speed=2256 wps
(16.14%)    Perplexity=175.375    Speed=2271 wps
(16.79%)    Perplexity=175.479    Speed=2289 wps
(17.43%)    Perplexity=175.085    Speed=2305 wps
(18.08%)    Perplexity=174.506    Speed=2319 wps
(18.72%)    Perplexity=174.926    Speed=2304 wps
(19.37%)    Perplexity=174.825    Speed=2305 wps
(20.01%)    Perplexity=174.429    Speed=2313 wps
(20.66%)    Perplexity=172.994    Speed=2319 wps
(21.30%)    Perplexity=172.299    Speed=2311 wps
(21.95%)    Perplexity=171.528    Speed=2296 wps
(22.60%)    Perplexity=170.509    Speed=2288 wps
(23.24%)    Perplexity=170.060    Speed=2296 wps
(23.89%)    Perplexity=170.266    Speed=2283 wps
(24.53%)    Perplexity=170.339    Speed=2282 wps
(25.18%)    Perplexity=169.704    Speed=2286 wps
(25.82%)    Perplexity=168.929    Speed=2270 wps
(26.47%)    Perplexity=168.673    Speed=2266 wps
(27.11%)    Perplexity=168.057    Speed=2268 wps
(27.76%)    Perplexity=167.598    Speed=2261 wps
(28.41%)    Perplexity=167.003    Speed=2255 wps
(29.05%)    Perplexity=166.981    Speed=2259 wps
(29.70%)    Perplexity=166.569    Speed=2252 wps
(30.34%)    Perplexity=165.933    Speed=2248 wps
(30.99%)    Perplexity=165.575    Speed=2253 wps
(31.63%)    Perplexity=165.151    Speed=2246 wps
(32.28%)    Perplexity=164.447    Speed=2249 wps
(32.92%)    Perplexity=164.168    Speed=2255 wps
(33.57%)    Perplexity=163.350    Speed=2250 wps
(34.22%)    Perplexity=162.883    Speed=2254 wps
(34.86%)    Perplexity=162.282    Speed=2260 wps
(35.51%)    Perplexity=161.395    Speed=2258 wps
(36.15%)    Perplexity=160.114    Speed=2257 wps
(36.80%)    Perplexity=159.097    Speed=2264 wps
(37.44%)    Perplexity=158.870    Speed=2260 wps
(38.09%)    Perplexity=158.498    Speed=2255 wps
(38.73%)    Perplexity=158.236    Speed=2262 wps
(39.38%)    Perplexity=158.177    Speed=2267 wps
(40.03%)    Perplexity=158.073    Speed=2262 wps
(40.67%)    Perplexity=157.642    Speed=2271 wps
(41.32%)    Perplexity=157.352    Speed=2280 wps
(41.96%)    Perplexity=156.894    Speed=2288 wps
(42.61%)    Perplexity=156.559    Speed=2296 wps
(43.25%)    Perplexity=156.025    Speed=2292 wps
(43.90%)    Perplexity=155.809    Speed=2293 wps
(44.54%)    Perplexity=156.223    Speed=2300 wps
(45.19%)    Perplexity=156.338    Speed=2294 wps
(45.84%)    Perplexity=156.373    Speed=2292 wps
(46.48%)    Perplexity=156.166    Speed=2296 wps
(47.13%)    Perplexity=155.962    Speed=2293 wps
(47.77%)    Perplexity=155.591    Speed=2293 wps
(48.42%)    Perplexity=155.092    Speed=2299 wps
(49.06%)    Perplexity=154.775    Speed=2298 wps
(49.71%)    Perplexity=154.608    Speed=2295 wps
(50.36%)    Perplexity=154.543    Speed=2299 wps
(51.00%)    Perplexity=154.802    Speed=2302 wps
(51.65%)    Perplexity=154.758    Speed=2297 wps
(52.29%)    Perplexity=154.531    Speed=2301 wps
(52.94%)    Perplexity=154.391    Speed=2307 wps
(53.58%)    Perplexity=154.058    Speed=2306 wps
(54.23%)    Perplexity=153.763    Speed=2301 wps
(54.87%)    Perplexity=153.559    Speed=2305 wps
(55.52%)    Perplexity=153.536    Speed=2311 wps
(56.17%)    Perplexity=153.378    Speed=2318 wps
(56.81%)    Perplexity=153.336    Speed=2324 wps
(57.46%)    Perplexity=153.070    Speed=2332 wps
(58.10%)    Perplexity=152.795    Speed=2332 wps
(58.75%)    Perplexity=152.196    Speed=2330 wps
(59.39%)    Perplexity=151.562    Speed=2335 wps
(60.04%)    Perplexity=151.197    Speed=2337 wps
(60.68%)    Perplexity=150.852    Speed=2333 wps
(61.33%)    Perplexity=150.483    Speed=2338 wps
(61.98%)    Perplexity=149.942    Speed=2341 wps
(62.62%)    Perplexity=149.507    Speed=2338 wps
(63.27%)    Perplexity=149.248    Speed=2342 wps
(63.91%)    Perplexity=149.033    Speed=2347 wps
(64.56%)    Perplexity=148.742    Speed=2352 wps
(65.20%)    Perplexity=148.422    Speed=2358 wps
(65.85%)    Perplexity=148.545    Speed=2359 wps
(66.49%)    Perplexity=148.507    Speed=2354 wps
(67.14%)    Perplexity=148.505    Speed=2353 wps
(67.79%)    Perplexity=148.488    Speed=2356 wps
(68.43%)    Perplexity=148.296    Speed=2353 wps
(69.08%)    Perplexity=148.208    Speed=2350 wps
(69.72%)    Perplexity=148.181    Speed=2352 wps
(70.37%)    Perplexity=147.992    Speed=2356 wps
(71.01%)    Perplexity=147.829    Speed=2353 wps
(71.66%)    Perplexity=147.838    Speed=2354 wps
(72.30%)    Perplexity=147.588    Speed=2359 wps
(72.95%)    Perplexity=147.336    Speed=2364 wps
(73.60%)    Perplexity=147.086    Speed=2369 wps
(74.24%)    Perplexity=146.815    Speed=2373 wps
(74.89%)    Perplexity=146.519    Speed=2370 wps
(75.53%)    Perplexity=146.413    Speed=2372 wps
(76.18%)    Perplexity=146.172    Speed=2377 wps
(76.82%)    Perplexity=146.203    Speed=2381 wps
(77.47%)    Perplexity=146.184    Speed=2386 wps
(78.11%)    Perplexity=146.123    Speed=2390 wps
(78.76%)    Perplexity=145.998    Speed=2393 wps
(79.41%)    Perplexity=146.028    Speed=2389 wps
(80.05%)    Perplexity=145.933    Speed=2387 wps
(80.70%)    Perplexity=145.923    Speed=2390 wps
(81.34%)    Perplexity=145.830    Speed=2394 wps
(81.99%)    Perplexity=145.522    Speed=2390 wps
(82.63%)    Perplexity=145.367    Speed=2390 wps
(83.28%)    Perplexity=145.211    Speed=2393 wps
(83.93%)    Perplexity=144.794    Speed=2392 wps
(84.57%)    Perplexity=144.423    Speed=2390 wps
(85.22%)    Perplexity=143.854    Speed=2395 wps
(85.86%)    Perplexity=143.468    Speed=2399 wps
(86.51%)    Perplexity=143.066    Speed=2395 wps
(87.15%)    Perplexity=142.697    Speed=2399 wps
(87.80%)    Perplexity=142.529    Speed=2403 wps
(88.44%)    Perplexity=142.305    Speed=2407 wps
(89.09%)    Perplexity=142.278    Speed=2403 wps
(89.74%)    Perplexity=142.357    Speed=2407 wps
(90.38%)    Perplexity=142.208    Speed=2411 wps
(91.03%)    Perplexity=142.087    Speed=2415 wps
(91.67%)    Perplexity=141.731    Speed=2409 wps
(92.32%)    Perplexity=141.539    Speed=2409 wps
(92.96%)    Perplexity=141.373    Speed=2411 wps
(93.61%)    Perplexity=141.020    Speed=2409 wps
(94.25%)    Perplexity=140.789    Speed=2408 wps
(94.90%)    Perplexity=140.528    Speed=2411 wps
(95.55%)    Perplexity=140.319    Speed=2413 wps
(96.19%)    Perplexity=140.145    Speed=2410 wps
(96.84%)    Perplexity=140.117    Speed=2410 wps
(97.48%)    Perplexity=140.045    Speed=2413 wps
(98.13%)    Perplexity=139.976    Speed=2411 wps
(98.77%)    Perplexity=139.864    Speed=2409 wps
(99.42%)    Perplexity=139.831    Speed=2412 wps
Epoch 2 : Train Perplexity: 139.777
Epoch 2 : Valid Perplexity: 149.756
Epoch 3 : Learning rate: 0.462
(0.00%)    Perplexity=192.778    Speed=2004 wps
(0.65%)    Perplexity=146.506    Speed=2713 wps
(1.29%)    Perplexity=136.844    Speed=2378 wps
(1.94%)    Perplexity=128.000    Speed=2247 wps
(2.58%)    Perplexity=125.233    Speed=2308 wps
(3.23%)    Perplexity=127.019    Speed=2398 wps
(3.87%)    Perplexity=126.300    Speed=2334 wps
(4.52%)    Perplexity=124.113    Speed=2365 wps
(5.16%)    Perplexity=125.028    Speed=2432 wps
(5.81%)    Perplexity=127.201    Speed=2489 wps
(6.46%)    Perplexity=127.139    Speed=2421 wps
(7.10%)    Perplexity=129.245    Speed=2473 wps
(7.75%)    Perplexity=131.105    Speed=2512 wps
(8.39%)    Perplexity=131.459    Speed=2543 wps
(9.04%)    Perplexity=131.869    Speed=2559 wps
(9.68%)    Perplexity=132.644    Speed=2514 wps
(10.33%)    Perplexity=132.720    Speed=2526 wps
(10.97%)    Perplexity=133.832    Speed=2544 wps
(11.62%)    Perplexity=133.898    Speed=2508 wps
(12.27%)    Perplexity=133.384    Speed=2475 wps
(12.91%)    Perplexity=132.706    Speed=2480 wps
(13.56%)    Perplexity=132.272    Speed=2500 wps
(14.20%)    Perplexity=131.372    Speed=2508 wps
(14.85%)    Perplexity=131.412    Speed=2484 wps
(15.49%)    Perplexity=130.870    Speed=2482 wps
(16.14%)    Perplexity=130.913    Speed=2495 wps
(16.79%)    Perplexity=131.194    Speed=2474 wps
(17.43%)    Perplexity=131.034    Speed=2474 wps
(18.08%)    Perplexity=130.669    Speed=2491 wps
(18.72%)    Perplexity=131.151    Speed=2509 wps
(19.37%)    Perplexity=131.185    Speed=2489 wps
(20.01%)    Perplexity=130.975    Speed=2504 wps
(20.66%)    Perplexity=130.017    Speed=2519 wps
(21.30%)    Perplexity=129.511    Speed=2531 wps
(21.95%)    Perplexity=129.001    Speed=2531 wps
(22.60%)    Perplexity=128.324    Speed=2512 wps
(23.24%)    Perplexity=128.052    Speed=2516 wps
(23.89%)    Perplexity=128.287    Speed=2524 wps
(24.53%)    Perplexity=128.439    Speed=2509 wps
(25.18%)    Perplexity=128.075    Speed=2509 wps
(25.82%)    Perplexity=127.619    Speed=2517 wps
(26.47%)    Perplexity=127.524    Speed=2517 wps
(27.11%)    Perplexity=127.080    Speed=2502 wps
(27.76%)    Perplexity=126.818    Speed=2502 wps
(28.41%)    Perplexity=126.466    Speed=2509 wps
(29.05%)    Perplexity=126.526    Speed=2500 wps
(29.70%)    Perplexity=126.159    Speed=2496 wps
(30.34%)    Perplexity=125.667    Speed=2506 wps
(30.99%)    Perplexity=125.384    Speed=2493 wps
(31.63%)    Perplexity=125.162    Speed=2503 wps
(32.28%)    Perplexity=124.733    Speed=2512 wps
(32.92%)    Perplexity=124.613    Speed=2521 wps
(33.57%)    Perplexity=124.068    Speed=2510 wps
(34.22%)    Perplexity=123.794    Speed=2504 wps
(34.86%)    Perplexity=123.352    Speed=2500 wps
(35.51%)    Perplexity=122.689    Speed=2484 wps
(36.15%)    Perplexity=121.719    Speed=2481 wps
(36.80%)    Perplexity=120.971    Speed=2488 wps
(37.44%)    Perplexity=120.888    Speed=2486 wps
(38.09%)    Perplexity=120.695    Speed=2475 wps
(38.73%)    Perplexity=120.551    Speed=2474 wps
(39.38%)    Perplexity=120.579    Speed=2480 wps
(40.03%)    Perplexity=120.556    Speed=2475 wps
(40.67%)    Perplexity=120.314    Speed=2466 wps
(41.32%)    Perplexity=120.173    Speed=2471 wps
(41.96%)    Perplexity=119.912    Speed=2473 wps
(42.61%)    Perplexity=119.692    Speed=2465 wps
(43.25%)    Perplexity=119.372    Speed=2471 wps
(43.90%)    Perplexity=119.267    Speed=2478 wps
(44.54%)    Perplexity=119.652    Speed=2469 wps
(45.19%)    Perplexity=119.843    Speed=2476 wps
(45.84%)    Perplexity=119.890    Speed=2483 wps
(46.48%)    Perplexity=119.766    Speed=2475 wps
(47.13%)    Perplexity=119.662    Speed=2481 wps
(47.77%)    Perplexity=119.439    Speed=2487 wps
(48.42%)    Perplexity=119.098    Speed=2494 wps
(49.06%)    Perplexity=118.949    Speed=2486 wps
(49.71%)    Perplexity=118.901    Speed=2485 wps
(50.36%)    Perplexity=118.916    Speed=2480 wps
(51.00%)    Perplexity=119.188    Speed=2484 wps
(51.65%)    Perplexity=119.189    Speed=2480 wps
(52.29%)    Perplexity=119.073    Speed=2473 wps
(52.94%)    Perplexity=119.021    Speed=2479 wps
(53.58%)    Perplexity=118.835    Speed=2472 wps
(54.23%)    Perplexity=118.656    Speed=2477 wps
(54.87%)    Perplexity=118.579    Speed=2474 wps
(55.52%)    Perplexity=118.646    Speed=2477 wps
(56.17%)    Perplexity=118.572    Speed=2468 wps
(56.81%)    Perplexity=118.604    Speed=2474 wps
(57.46%)    Perplexity=118.441    Speed=2467 wps
(58.10%)    Perplexity=118.268    Speed=2472 wps
(58.75%)    Perplexity=117.828    Speed=2460 wps
(59.39%)    Perplexity=117.362    Speed=2461 wps
(60.04%)    Perplexity=117.087    Speed=2463 wps
(60.68%)    Perplexity=116.839    Speed=2467 wps
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-64-1fb844799b2c> in <module>()
     26         # Run the loop for this epoch in the training model
     27         train_perplexity = run_epoch(session, m, train_data, m.train_op,
---> 28                                    verbose=True)
     29         print("Epoch %d : Train Perplexity: %.3f" % (i + 1, train_perplexity))
     30 

<ipython-input-62-1b01bc8c3370> in run_epoch(session, m, data, eval_op, verbose)
     21                                      {m.input_data: x,
     22                                       m.targets: y,
---> 23                                       m.initial_state: state})
     24 
     25         #Add returned cost to costs (which keeps track of the total costs for this epoch)

/home/santi/miniconda3/envs/data_science/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/santi/miniconda3/envs/data_science/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/home/santi/miniconda3/envs/data_science/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/santi/miniconda3/envs/data_science/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1037   def _do_call(self, fn, *args):
   1038     try:
-> 1039       return fn(*args)
   1040     except errors.OpError as e:
   1041       message = compat.as_text(e.message)

/home/santi/miniconda3/envs/data_science/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1019         return tf_session.TF_Run(session, options,
   1020                                  feed_dict, fetch_list, target_list,
-> 1021                                  status, run_metadata)
   1022 
   1023     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

As you can see, the model's perplexity drops very quickly over the first epochs of training. As was explained before, lower perplexity means that the model is more certain about its predictions. As such, we can be confident that this model is learning well!
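To make the relationship between the logged numbers and the loss explicit: run_epoch accumulates in costs the cross-entropy loss summed over time steps (already averaged over the batch) and in iters the number of time steps processed, so the reported value np.exp(costs / iters) is the exponential of the average per-word cross-entropy. A minimal sketch with made-up loss values:

import numpy as np

# Made-up average cross-entropy losses (in nats), one value per predicted word position
step_losses = np.array([9.2, 8.1, 7.4, 6.9])

costs = step_losses.sum()    # analogous to run_epoch's running `costs`
iters = len(step_losses)     # analogous to run_epoch's running `iters`

# Perplexity is the exponential of the average per-word cross-entropy:
# a lower average loss means a lower perplexity and a more confident model
perplexity = np.exp(costs / iters)
print("Perplexity: %.3f" % perplexity)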


This is the end of the Applying Recurrent Neural Networks/LSTM for Language Modelling notebook. Hopefully you now have a better understanding of Recurrent Neural Networks and how to implement one using TensorFlow. Thank you for reading this notebook, and good luck with your studies.

