Sentiment Analysis with an RNN

In this notebook, you'll implement a recurrent neural network that performs sentiment analysis. Using an RNN rather than a feedforward network tends to be more accurate because we can include information about the sequence of words. Here we'll use a dataset of movie reviews, accompanied by labels.

The architecture for this network is shown below.

Here, we'll pass in words to an embedding layer. We need an embedding layer because we have tens of thousands of words, so we'll need a more efficient representation for our input data than one-hot encoded vectors. You may have seen this before in the word2vec lesson. You can actually train up an embedding with word2vec and use it here, but it's good enough to just have an embedding layer and let the network learn the embedding table on its own.
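
To see why a lookup table works, here is a tiny, purely illustrative NumPy example (the sizes are made up): picking a row of the embedding matrix by integer index gives the same vector as multiplying a one-hot vector by that matrix, without ever building the one-hot vector.

import numpy as np
vocab_size, embed_dim = 10, 4                 # toy sizes, just for illustration
embedding = np.random.randn(vocab_size, embed_dim)
word_id = 7                                   # the integer encoding of some word
one_hot = np.zeros(vocab_size)
one_hot[word_id] = 1
print(np.allclose(one_hot @ embedding, embedding[word_id]))   # True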

From the embedding layer, the new representations will be passed to LSTM cells. These will add recurrent connections to the network so we can include information about the sequence of words in the data. Finally, the LSTM outputs will go to a sigmoid output layer. We're using a sigmoid because we're trying to predict whether this text has positive or negative sentiment. The output layer will just be a single unit, then, with a sigmoid activation function.

We don't care about the sigmoid outputs except for the very last one; we can ignore the rest. We'll calculate the cost from the output of the last step and the training label.
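
For a rough sense of the tensor shapes involved, here's an illustrative sketch using the hyperparameter values this notebook settles on later (batch size 100, sequence length 200, embedding size 300, LSTM size 256):

# Rough tensor shapes through the network (illustrative):
#   inputs_  (integer word ids)           -> [100, 200]
#   embedding lookup                      -> [100, 200, 300]
#   LSTM outputs (one per time step)      -> [100, 200, 256]
#   last time step, outputs[:, -1]        -> [100, 256]
#   sigmoid output layer (1 unit)         -> [100, 1]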


In [1]:
import numpy as np
import tensorflow as tf

In [2]:
with open('../sentiment-network/reviews.txt', 'r') as f:
    reviews = f.read()
with open('../sentiment-network/labels.txt', 'r') as f:
    labels = f.read()

In [3]:
reviews[:2000]


Out[3]:
'bromwell high is a cartoon comedy . it ran at the same time as some other programs about school life  such as  teachers  . my   years in the teaching profession lead me to believe that bromwell high  s satire is much closer to reality than is  teachers  . the scramble to survive financially  the insightful students who can see right through their pathetic teachers  pomp  the pettiness of the whole situation  all remind me of the schools i knew and their students . when i saw the episode in which a student repeatedly tried to burn down the school  i immediately recalled . . . . . . . . . at . . . . . . . . . . high . a classic line inspector i  m here to sack one of your teachers . student welcome to bromwell high . i expect that many adults of my age think that bromwell high is far fetched . what a pity that it isn  t   \nstory of a man who has unnatural feelings for a pig . starts out with a opening scene that is a terrific example of absurd comedy . a formal orchestra audience is turned into an insane  violent mob by the crazy chantings of it  s singers . unfortunately it stays absurd the whole time with no general narrative eventually making it just too off putting . even those from the era should be turned off . the cryptic dialogue would make shakespeare seem easy to a third grader . on a technical level it  s better than you might think with some good cinematography by future great vilmos zsigmond . future stars sally kirkland and frederic forrest can be seen briefly .  \nhomelessness  or houselessness as george carlin stated  has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school  work  or vote for the matter . most people think of the homeless as just a lost cause while worrying about things such as racism  the war on iraq  pressuring kids to succeed  technology  the elections  inflation  or worrying if they  ll be next to end up on the streets .  br    br   but what if y'

Data preprocessing

The first step when building a neural network model is getting your data into the proper form to feed into the network. Since we're using embedding layers, we'll need to encode each word with an integer. We'll also want to clean it up a bit.

You can see an example of the reviews data above. We'll want to get rid of those periods. Also, you might notice that the reviews are delimited with newlines \n. To deal with those, I'm going to split the text into individual reviews using \n as the delimiter. Then I can combine all the reviews back together into one big string.

First, let's remove all punctuation. Then get all the text without the newlines and split it into individual words.


In [4]:
from string import punctuation
all_text = ''.join([c for c in reviews if c not in punctuation])
reviews = all_text.split('\n')

all_text = ' '.join(reviews)
words = all_text.split()

In [5]:
all_text[:2000]


Out[5]:
'bromwell high is a cartoon comedy  it ran at the same time as some other programs about school life  such as  teachers   my   years in the teaching profession lead me to believe that bromwell high  s satire is much closer to reality than is  teachers   the scramble to survive financially  the insightful students who can see right through their pathetic teachers  pomp  the pettiness of the whole situation  all remind me of the schools i knew and their students  when i saw the episode in which a student repeatedly tried to burn down the school  i immediately recalled          at           high  a classic line inspector i  m here to sack one of your teachers  student welcome to bromwell high  i expect that many adults of my age think that bromwell high is far fetched  what a pity that it isn  t    story of a man who has unnatural feelings for a pig  starts out with a opening scene that is a terrific example of absurd comedy  a formal orchestra audience is turned into an insane  violent mob by the crazy chantings of it  s singers  unfortunately it stays absurd the whole time with no general narrative eventually making it just too off putting  even those from the era should be turned off  the cryptic dialogue would make shakespeare seem easy to a third grader  on a technical level it  s better than you might think with some good cinematography by future great vilmos zsigmond  future stars sally kirkland and frederic forrest can be seen briefly    homelessness  or houselessness as george carlin stated  has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school  work  or vote for the matter  most people think of the homeless as just a lost cause while worrying about things such as racism  the war on iraq  pressuring kids to succeed  technology  the elections  inflation  or worrying if they  ll be next to end up on the streets   br    br   but what if you were given a bet to live on the st'

In [6]:
words[:100]


Out[6]:
['bromwell',
 'high',
 'is',
 'a',
 'cartoon',
 'comedy',
 'it',
 'ran',
 'at',
 'the',
 'same',
 'time',
 'as',
 'some',
 'other',
 'programs',
 'about',
 'school',
 'life',
 'such',
 'as',
 'teachers',
 'my',
 'years',
 'in',
 'the',
 'teaching',
 'profession',
 'lead',
 'me',
 'to',
 'believe',
 'that',
 'bromwell',
 'high',
 's',
 'satire',
 'is',
 'much',
 'closer',
 'to',
 'reality',
 'than',
 'is',
 'teachers',
 'the',
 'scramble',
 'to',
 'survive',
 'financially',
 'the',
 'insightful',
 'students',
 'who',
 'can',
 'see',
 'right',
 'through',
 'their',
 'pathetic',
 'teachers',
 'pomp',
 'the',
 'pettiness',
 'of',
 'the',
 'whole',
 'situation',
 'all',
 'remind',
 'me',
 'of',
 'the',
 'schools',
 'i',
 'knew',
 'and',
 'their',
 'students',
 'when',
 'i',
 'saw',
 'the',
 'episode',
 'in',
 'which',
 'a',
 'student',
 'repeatedly',
 'tried',
 'to',
 'burn',
 'down',
 'the',
 'school',
 'i',
 'immediately',
 'recalled',
 'at',
 'high']

Encoding the words

The embedding lookup requires that we pass in integers to our network. The easiest way to do this is to create dictionaries that map the words in the vocabulary to integers. Then we can convert each of our reviews into integers so they can be passed into the network.

Exercise: Now you're going to encode the words with integers. Build a dictionary that maps words to integers. Later we're going to pad our input vectors with zeros, so make sure the integers start at 1, not 0. Also, convert the reviews to integers and store the reviews in a new list called reviews_ints.


In [7]:
from collections import Counter
counts = Counter(words)
vocab = sorted(counts, key=counts.get, reverse=True)
vocab_to_int = {word: ii for ii, word in enumerate(vocab, 1)}

reviews_ints = []
for each in reviews:
    reviews_ints.append([vocab_to_int[word] for word in each.split()])
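
As a quick illustrative check (an aside, not required by the exercise), the most frequent words should map to the smallest integers, and the vocabulary size tells us how large the embedding table will need to be later:

print("Vocabulary size:", len(vocab_to_int))                  # roughly 74,000 unique words
print([(word, vocab_to_int[word]) for word in vocab[:5]])     # most common words map to 1, 2, 3, ...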

Encoding the labels

Our labels are "positive" or "negative". To use these labels in our network, we need to convert them to 0 and 1.

Exercise: Convert labels from positive and negative to 1 and 0, respectively.


In [8]:
labels = labels.split('\n')
labels = np.array([1 if each == 'positive' else 0 for each in labels])

In [9]:
review_lens = Counter([len(x) for x in reviews_ints])
print("Zero-length reviews: {}".format(review_lens[0]))
print("Maximum review length: {}".format(max(review_lens)))


Zero-length reviews: 1
Maximum review length: 2514

Okay, a couple of issues here. We seem to have one review with zero length. And the maximum review length is way too many steps for our RNN. Let's truncate to 200 steps. For reviews shorter than 200, we'll pad with 0s. For reviews longer than 200, we can truncate them to the first 200 words.

Exercise: First, remove the review with zero length from the reviews_ints list.


In [10]:
non_zero_idx = [ii for ii, review in enumerate(reviews_ints) if len(review) != 0]
len(non_zero_idx)


Out[10]:
25000

In [11]:
reviews_ints[-1]


Out[11]:
[]

Turns out it's the final review that has zero length. But that might not always be the case, so let's make it more general.


In [12]:
reviews_ints = [reviews_ints[ii] for ii in non_zero_idx]
labels = np.array([labels[ii] for ii in non_zero_idx])

Exercise: Now, create an array features that contains the data we'll pass to the network. The data should come from reviews_ints, since we want to feed integers to the network. Each row should be 200 elements long. For reviews shorter than 200 words, left pad with 0s. That is, if the review is ['best', 'movie', 'ever'], [117, 18, 128] as integers, the row will look like [0, 0, 0, ..., 0, 117, 18, 128]. For reviews longer than 200, use only the first 200 words as the feature vector.

This isn't trivial and there are a bunch of ways to do this. But, if you're going to be building your own deep learning networks, you're going to have to get used to preparing your data.


In [13]:
seq_len = 200
features = np.zeros((len(reviews_ints), seq_len), dtype=int)
for i, row in enumerate(reviews_ints):
    features[i, -len(row):] = np.array(row)[:seq_len]
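
As an illustrative sanity check (one of many possible spot-checks, not part of the exercise), every row of features should be exactly seq_len long and a short review should show up left-padded with zeros:

assert features.shape == (len(reviews_ints), seq_len)
shortest = min(range(len(reviews_ints)), key=lambda ii: len(reviews_ints[ii]))
print(len(reviews_ints[shortest]), features[shortest, :10])   # short review -> leading zeros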

In [14]:
features[:10,:100]


Out[14]:
array([[    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0, 21025,   308,     6,
            3,  1050,   207,     8,  2138,    32,     1,   171,    57,
           15,    49,    81,  5785,    44,   382,   110,   140,    15,
         5194,    60,   154,     9,     1,  4975,  5852,   475,    71,
            5,   260,    12, 21025,   308,    13,  1978,     6,    74,
         2395],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,    63,     4,     3,   125,
           36,    47,  7472,  1395,    16,     3,  4181,   505,    45,
           17],
       [22382,    42, 46418,    15,   706, 17139,  3389,    47,    77,
           35,  1819,    16,   154,    19,   114,     3,  1305,     5,
          336,   147,    22,     1,   857,    12,    70,   281,  1168,
          399,    36,   120,   283,    38,   169,     5,   382,   158,
           42,  2269,    16,     1,   541,    90,    78,   102,     4,
            1,  3244,    15,    43,     3,   407,  1068,   136,  8055,
           44,   182,   140,    15,  3043,     1,   320,    22,  4818,
        26224,   346,     5,  3090,  2092,     1, 18839, 17939,    42,
         8055,    46,    33,   236,    29,   370,     5,   130,    56,
           22,     1,  1928,     7,     7,    19,    48,    46,    21,
           70,   344,     3,  2099,     5,   408,    22,     1,  1928,
           16],
       [ 4505,   505,    15,     3,  3342,   162,  8312,  1652,     6,
         4819,    56,    17,  4504,  5616,   140, 11725,     5,   996,
         4919,  2933,  4462,   566,  1201,    36,     6,  1518,    96,
            3,   744,     4, 26225,    13,     5,    27,  3461,     9,
        10625,     4,     8,   111,  3013,     5,     1,  1027,    15,
            3,  4390,    82,    22,  2049,     6,  4462,   538,  2764,
         7073, 37443,    41,   463,     1,  8312, 46419,   302,   123,
           15,  4221,    19,  1667,   922,     1,  1652,     6,  6129,
        19871,    34,     1,   980,  1751, 22383,   646, 24104,    27,
          106, 11726,    13, 14045, 15097, 17940,  2457,   466, 21027,
           36,  3266,     1,  6365,  1020,    45,    17,  2695,  2499,
           33],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,   520,   119,   113,    34,
        16372,  1816,  3737,   117,   885, 21030,   721,    10,    28,
          124,   108,     2,   115,   137,     9,  1623,  7691,    26,
          330,     5,   589,     1,  6130,    22,   386,     6,     3,
          349,    15,    50,    15,   231,     9,  7473, 11399,     1,
          191,    22,  8966,     6,    82,   880,   101,   111,  3584,
            4],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
           11,    20,  3637,   141,    10,   422,    23,   272,    60,
         4355,    22,    32,    84,  3286,    22,     1,   172,     4,
            1,   952,   507,    11,  4977,  5361,     5,   574,     4,
         1155,    54,    53,  5304,     1,   261,    17,    41,   952,
          125,    59,     1,   711,   137,   379,   626,    15,   111,
         1509],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,    11,     6,   692,     1,    90,
         2156,    20, 11728,     1,  2818,  5195,   249,    92,  3006,
            8,   126,    24,   200,     3,   802,   634,     4, 22382,
         1001],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,   786,   295,    10,   122,    11,     6,   419,
            5,    29,    35,   482,    20,    19,  1281,    33,   142,
           28,  2657,    45,  1840,    32,     1,  2778,    37,    78,
           97,  2436,    67,  3950,    45,     2,    24,   105,   256,
            1,   134,  1571,     2, 12399,   451,    14,   319,    11,
           63,     6,    98,  1321,     5,   105,     1,  3767,     4,
            3],
       [    0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,    11,     6,
           24,     1,   779,  3687,  2818,    20,     8,    14,    74,
          325,  2730,    73,    90,     4,    27,    99,     2,   165,
           68],
       [   54,    10,    14,   116,    60,   798,   552,    71,   364,
            5,     1,   730,     5,    66,  8057,     8,    14,    30,
            4,   109,    99,    10,   293,    17,    60,   798,    19,
           11,    14,     1,    64,    30,    69,  2500,    45,     4,
          234,    93,    10,    68,   114,   108,  8057,   363,    43,
         1009,     2,    10,    97,    28,  1431,    45,     1,   357,
            4,    60,   110,   205,     8,    48,     3,  1929, 10880,
            2,  2124,   354,   412,     4,    13,  6609,     2,  2974,
         5148,  2125,  1366,     6,    30,     4,    60,   502,   876,
           19,  8057,     6,    34,   227,     1,   247,   412,     4,
          582,     4,    27,   599,     9,     1, 13586,   396,     4,
        14047]])

Training, Validation, Test

With our data in nice shape, we'll split it into training, validation, and test sets.

Exercise: Create the training, validation, and test sets here. You'll need to create sets for the features and the labels, train_x and train_y for example. Define a split fraction, split_frac as the fraction of data to keep in the training set. Usually this is set to 0.8 or 0.9. The rest of the data will be split in half to create the validation and testing data.


In [15]:
split_frac = 0.8
split_idx = int(len(features)*split_frac)
train_x, val_x = features[:split_idx], features[split_idx:]
train_y, val_y = labels[:split_idx], labels[split_idx:]

test_idx = int(len(val_x)*0.5)
val_x, test_x = val_x[:test_idx], val_x[test_idx:]
val_y, test_y = val_y[:test_idx], val_y[test_idx:]

print("\t\t\tFeature Shapes:")
print("Train set: \t\t{}".format(train_x.shape), 
      "\nValidation set: \t{}".format(val_x.shape),
      "\nTest set: \t\t{}".format(test_x.shape))


			Feature Shapes:
Train set: 		(20000, 200) 
Validation set: 	(2500, 200) 
Test set: 		(2500, 200)

With train, validation, and test fractions of 0.8, 0.1, and 0.1, the final shapes should look like:

                    Feature Shapes:
Train set:       (20000, 200) 
Validation set:     (2500, 200) 
Test set:         (2500, 200)

Build the graph

Here, we'll build the graph. First up, defining the hyperparameters.

  • lstm_size: Number of units in the hidden layers in the LSTM cells. Usually, larger is better performance-wise. Common values are 128, 256, 512, etc.
  • lstm_layers: Number of LSTM layers in the network. I'd start with 1, then add more if I'm underfitting.
  • batch_size: The number of reviews to feed the network in one training pass. Typically this should be set as high as you can go without running out of memory.
  • learning_rate: Learning rate

In [16]:
lstm_size = 256
lstm_layers = 1
batch_size = 100
learning_rate = 0.001

tf.reset_default_graph()

For the network itself, we'll be passing in our 200 element long review vectors. Each batch will be batch_size vectors. We'll also be using dropout on the LSTM layer, so we'll make a placeholder for the keep probability.

Exercise: Create the inputs_, labels_, and dropout keep_prob placeholders using tf.placeholder. labels_ needs to be two-dimensional to work with some functions later. Since keep_prob is a scalar (a 0-dimensional tensor), you shouldn't provide a size to tf.placeholder.


In [17]:
n_words = len(vocab_to_int) + 1 # Adding 1 because we use 0's for padding, dictionary started at 1

# Create the graph object
graph = tf.Graph()
# Add nodes to the graph
with graph.as_default():
    inputs_ = tf.placeholder(tf.int32, [None, None], name='inputs')
    labels_ = tf.placeholder(tf.int32, [None, None], name='labels')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

Embedding

Now we'll add an embedding layer. We need to do this because there are 74000 words in our vocabulary. It is massively inefficient to one-hot encode our classes here. You should remember dealing with this problem from the word2vec lesson. Instead of one-hot encoding, we can have an embedding layer and use that layer as a lookup table. You could train an embedding layer using word2vec, then load it here. But, it's fine to just make a new layer and let the network learn the weights.

Exercise: Create the embedding lookup matrix as a tf.Variable. Use that embedding matrix to get the embedded vectors to pass to the LSTM cell with tf.nn.embedding_lookup. This function takes the embedding matrix and an input tensor, such as the review vectors. Then, it'll return another tensor with the embedded vectors. So, if the embedding layer has 200 units, each word gets mapped to a 200-dimensional vector and the function will return a tensor with size [batch_size, seq_len, 200].


In [18]:
# Size of the embedding vectors (number of units in the embedding layer)
embed_size = 300 

with graph.as_default():
    embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1))
    embed = tf.nn.embedding_lookup(embedding, inputs_)
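
To confirm the shape reasoning from the exercise text (an illustrative check; get_shape doesn't need a session):

print(embed.get_shape())   # (?, ?, 300), i.e. [batch_size, seq_len, embed_size]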

LSTM cell

Next, we'll create our LSTM cells to use in the recurrent network (TensorFlow documentation). Here we are just defining what the cells look like. This isn't actually building the graph, just defining the type of cells we want in our graph.

To create a basic LSTM cell for the graph, you'll want to use tf.contrib.rnn.BasicLSTMCell. Looking at the function documentation:

tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0, input_size=None, state_is_tuple=True, activation=<function tanh at 0x109f1ef28>)

you can see it takes a parameter called num_units, the number of units in the cell, called lstm_size in this code. So then, you can write something like

lstm = tf.contrib.rnn.BasicLSTMCell(num_units)

to create an LSTM cell with num_units. Next, you can add dropout to the cell with tf.contrib.rnn.DropoutWrapper. This just wraps the cell in another cell, but with dropout added to the inputs and/or outputs. It's a really convenient way to make your network better with almost no effort! So you'd do something like

drop = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

Most of the time, your network will have better performance with more layers. That's sort of the magic of deep learning: adding more layers allows the network to learn really complex relationships. Again, there is a simple way to create multiple layers of LSTM cells with tf.contrib.rnn.MultiRNNCell:

cell = tf.contrib.rnn.MultiRNNCell([drop] * lstm_layers)

Here, [drop] * lstm_layers creates a list of cells (drop) that is lstm_layers long. The MultiRNNCell wrapper builds this into multiple layers of RNN cells, one for each cell in the list.
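
One caveat: depending on your TensorFlow 1.x version, reusing the exact same cell object for every layer this way can raise a variable-sharing error when lstm_layers is greater than one. A common workaround, sketched here with a hypothetical build_cell helper (not needed for the single layer used below), is to build a fresh cell per layer:

def build_cell(lstm_size, keep_prob):
    lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)

cell = tf.contrib.rnn.MultiRNNCell([build_cell(lstm_size, keep_prob) for _ in range(lstm_layers)])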

So the final cell you're using in the network is actually multiple (or just one) LSTM cells with dropout. But it all works the same from an architectural viewpoint, just a more complicated graph in the cell.

Exercise: Below, use tf.contrib.rnn.BasicLSTMCell to create an LSTM cell. Then, add dropout to it with tf.contrib.rnn.DropoutWrapper. Finally, create multiple LSTM layers with tf.contrib.rnn.MultiRNNCell.

Here is a tutorial on building RNNs that will help you out.


In [19]:
with graph.as_default():
    # Your basic LSTM cell
    lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    
    # Add dropout to the cell
    drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    
    # Stack up multiple LSTM layers, for deep learning
    cell = tf.contrib.rnn.MultiRNNCell([drop] * lstm_layers)
    
    # Getting an initial state of all zeros
    initial_state = cell.zero_state(batch_size, tf.float32)

RNN forward pass

Now we need to actually run the data through the RNN nodes. You can use tf.nn.dynamic_rnn to do this. You'd pass in the RNN cell you created (our multiple layered LSTM cell for instance), and the inputs to the network.

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

Above I created an initial state, initial_state, to pass to the RNN. This is the cell state that is carried from one time step to the next within each layer. tf.nn.dynamic_rnn takes care of most of the work for us. We pass in our cell and the input to the cell, then it does the unrolling and everything else for us. It returns outputs for each time step and the final_state of the hidden layer.

Exercise: Use tf.nn.dynamic_rnn to add the forward pass through the RNN. Remember that we're actually passing in vectors from the embedding layer, embed.


In [20]:
with graph.as_default():
    outputs, final_state = tf.nn.dynamic_rnn(cell, embed,
                                             initial_state=initial_state)

Output

We only care about the final output; we'll be using that as our sentiment prediction. So we need to grab the last output with outputs[:, -1], then calculate the cost from that and labels_.


In [21]:
with graph.as_default():
    predictions = tf.contrib.layers.fully_connected(outputs[:, -1], 1, activation_fn=tf.sigmoid)
    cost = tf.losses.mean_squared_error(labels_, predictions)
    
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
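
As an aside on this design choice: a sigmoid output trained with mean squared error works for a single 0/1 label, but a common alternative (just a sketch of what could replace the predictions and cost lines inside the same graph context, not what this notebook does) is to keep the output linear and use a sigmoid cross-entropy loss on the logits:

logits = tf.contrib.layers.fully_connected(outputs[:, -1], 1, activation_fn=None)
cost = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(labels_, tf.float32), logits=logits))
# With this variant, predictions for the accuracy check below would be tf.sigmoid(logits).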

Validation accuracy

Here we can add a few nodes to calculate the accuracy which we'll use in the validation pass.


In [22]:
with graph.as_default():
    correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

Batching

This is a simple function for returning batches from our data. First it removes data such that we only have full batches. Then it iterates through the x and y arrays and returns slices out of those arrays with size [batch_size].


In [23]:
def get_batches(x, y, batch_size=100):
    '''Yield (x, y) batches of size batch_size, dropping any leftover partial batch.'''
    n_batches = len(x)//batch_size
    x, y = x[:n_batches*batch_size], y[:n_batches*batch_size]
    for ii in range(0, len(x), batch_size):
        yield x[ii:ii+batch_size], y[ii:ii+batch_size]
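
A quick illustrative usage check (assuming the train_x and train_y arrays from the split above): pull one batch from the generator and confirm the shapes.

x, y = next(get_batches(train_x, train_y, batch_size))
print(x.shape, y.shape)   # expect (100, 200) and (100,)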

Training

Below is the typical training code. If you want to implement it yourself, feel free to delete all this code and write your own. Before you run this, make sure the checkpoints directory exists.


In [24]:
epochs = 10

with graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    iteration = 1
    for e in range(epochs):
        state = sess.run(initial_state)
        
        for ii, (x, y) in enumerate(get_batches(train_x, train_y, batch_size), 1):
            feed = {inputs_: x,
                    labels_: y[:, None],
                    keep_prob: 0.5,
                    initial_state: state}
            loss, state, _ = sess.run([cost, final_state, optimizer], feed_dict=feed)
            
            if iteration%5==0:
                print("Epoch: {}/{}".format(e, epochs),
                      "Iteration: {}".format(iteration),
                      "Train loss: {:.3f}".format(loss))

            if iteration%25==0:
                val_acc = []
                val_state = sess.run(cell.zero_state(batch_size, tf.float32))
                for x, y in get_batches(val_x, val_y, batch_size):
                    feed = {inputs_: x,
                            labels_: y[:, None],
                            keep_prob: 1,
                            initial_state: val_state}
                    batch_acc, val_state = sess.run([accuracy, final_state], feed_dict=feed)
                    val_acc.append(batch_acc)
                print("Val acc: {:.3f}".format(np.mean(val_acc)))
            iteration +=1
    saver.save(sess, "checkpoints/sentiment.ckpt")


Epoch: 0/10 Iteration: 5 Train loss: 0.249
Epoch: 0/10 Iteration: 10 Train loss: 0.247
Epoch: 0/10 Iteration: 15 Train loss: 0.224
Epoch: 0/10 Iteration: 20 Train loss: 0.245
Epoch: 0/10 Iteration: 25 Train loss: 0.192
Val acc: 0.680
Epoch: 0/10 Iteration: 30 Train loss: 0.412
Epoch: 0/10 Iteration: 35 Train loss: 0.282
Epoch: 0/10 Iteration: 40 Train loss: 0.282
Epoch: 0/10 Iteration: 45 Train loss: 0.255
Epoch: 0/10 Iteration: 50 Train loss: 0.251
Val acc: 0.526
Epoch: 0/10 Iteration: 55 Train loss: 0.251
Epoch: 0/10 Iteration: 60 Train loss: 0.237
Epoch: 0/10 Iteration: 65 Train loss: 0.240
Epoch: 0/10 Iteration: 70 Train loss: 0.233
Epoch: 0/10 Iteration: 75 Train loss: 0.249
Val acc: 0.571
Epoch: 0/10 Iteration: 80 Train loss: 0.240
Epoch: 0/10 Iteration: 85 Train loss: 0.234
Epoch: 0/10 Iteration: 90 Train loss: 0.210
Epoch: 0/10 Iteration: 95 Train loss: 0.230
Epoch: 0/10 Iteration: 100 Train loss: 0.335
Val acc: 0.559
Epoch: 0/10 Iteration: 105 Train loss: 0.323
Epoch: 0/10 Iteration: 110 Train loss: 0.248
Epoch: 0/10 Iteration: 115 Train loss: 0.250
Epoch: 0/10 Iteration: 120 Train loss: 0.247
Epoch: 0/10 Iteration: 125 Train loss: 0.250
Val acc: 0.554
Epoch: 0/10 Iteration: 130 Train loss: 0.244
Epoch: 0/10 Iteration: 135 Train loss: 0.227
Epoch: 0/10 Iteration: 140 Train loss: 0.231
Epoch: 0/10 Iteration: 145 Train loss: 0.247
Epoch: 0/10 Iteration: 150 Train loss: 0.242
Val acc: 0.580
Epoch: 0/10 Iteration: 155 Train loss: 0.234
Epoch: 0/10 Iteration: 160 Train loss: 0.237
Epoch: 0/10 Iteration: 165 Train loss: 0.230
Epoch: 0/10 Iteration: 170 Train loss: 0.241
Epoch: 0/10 Iteration: 175 Train loss: 0.220
Val acc: 0.626
Epoch: 0/10 Iteration: 180 Train loss: 0.230
Epoch: 0/10 Iteration: 185 Train loss: 0.219
Epoch: 0/10 Iteration: 190 Train loss: 0.234
Epoch: 0/10 Iteration: 195 Train loss: 0.158
Epoch: 0/10 Iteration: 200 Train loss: 0.334
Val acc: 0.532
Epoch: 1/10 Iteration: 205 Train loss: 0.393
Epoch: 1/10 Iteration: 210 Train loss: 0.284
Epoch: 1/10 Iteration: 215 Train loss: 0.271
Epoch: 1/10 Iteration: 220 Train loss: 0.239
Epoch: 1/10 Iteration: 225 Train loss: 0.285
Val acc: 0.604
Epoch: 1/10 Iteration: 230 Train loss: 0.245
Epoch: 1/10 Iteration: 235 Train loss: 0.235
Epoch: 1/10 Iteration: 240 Train loss: 0.231
Epoch: 1/10 Iteration: 245 Train loss: 0.252
Epoch: 1/10 Iteration: 250 Train loss: 0.243
Val acc: 0.562
Epoch: 1/10 Iteration: 255 Train loss: 0.235
Epoch: 1/10 Iteration: 260 Train loss: 0.237
Epoch: 1/10 Iteration: 265 Train loss: 0.236
Epoch: 1/10 Iteration: 270 Train loss: 0.218
Epoch: 1/10 Iteration: 275 Train loss: 0.247
Val acc: 0.584
Epoch: 1/10 Iteration: 280 Train loss: 0.214
Epoch: 1/10 Iteration: 285 Train loss: 0.217
Epoch: 1/10 Iteration: 290 Train loss: 0.186
Epoch: 1/10 Iteration: 295 Train loss: 0.184
Epoch: 1/10 Iteration: 300 Train loss: 0.212
Val acc: 0.501
Epoch: 1/10 Iteration: 305 Train loss: 0.346
Epoch: 1/10 Iteration: 310 Train loss: 0.242
Epoch: 1/10 Iteration: 315 Train loss: 0.260
Epoch: 1/10 Iteration: 320 Train loss: 0.225
Epoch: 1/10 Iteration: 325 Train loss: 0.223
Val acc: 0.599
Epoch: 1/10 Iteration: 330 Train loss: 0.231
Epoch: 1/10 Iteration: 335 Train loss: 0.216
Epoch: 1/10 Iteration: 340 Train loss: 0.205
Epoch: 1/10 Iteration: 345 Train loss: 0.219
Epoch: 1/10 Iteration: 350 Train loss: 0.215
Val acc: 0.639
Epoch: 1/10 Iteration: 355 Train loss: 0.202
Epoch: 1/10 Iteration: 360 Train loss: 0.211
Epoch: 1/10 Iteration: 365 Train loss: 0.161
Epoch: 1/10 Iteration: 370 Train loss: 0.246
Epoch: 1/10 Iteration: 375 Train loss: 0.318
Val acc: 0.502
Epoch: 1/10 Iteration: 380 Train loss: 0.249
Epoch: 1/10 Iteration: 385 Train loss: 0.248
Epoch: 1/10 Iteration: 390 Train loss: 0.257
Epoch: 1/10 Iteration: 395 Train loss: 0.224
Epoch: 1/10 Iteration: 400 Train loss: 0.248
Val acc: 0.575
Epoch: 2/10 Iteration: 405 Train loss: 0.244
Epoch: 2/10 Iteration: 410 Train loss: 0.232
Epoch: 2/10 Iteration: 415 Train loss: 0.234
Epoch: 2/10 Iteration: 420 Train loss: 0.238
Epoch: 2/10 Iteration: 425 Train loss: 0.235
Val acc: 0.598
Epoch: 2/10 Iteration: 430 Train loss: 0.224
Epoch: 2/10 Iteration: 435 Train loss: 0.218
Epoch: 2/10 Iteration: 440 Train loss: 0.228
Epoch: 2/10 Iteration: 445 Train loss: 0.215
Epoch: 2/10 Iteration: 450 Train loss: 0.234
Val acc: 0.628
Epoch: 2/10 Iteration: 455 Train loss: 0.214
Epoch: 2/10 Iteration: 460 Train loss: 0.207
Epoch: 2/10 Iteration: 465 Train loss: 0.191
Epoch: 2/10 Iteration: 470 Train loss: 0.169
Epoch: 2/10 Iteration: 475 Train loss: 0.170
Val acc: 0.632
Epoch: 2/10 Iteration: 480 Train loss: 0.223
Epoch: 2/10 Iteration: 485 Train loss: 0.153
Epoch: 2/10 Iteration: 490 Train loss: 0.117
Epoch: 2/10 Iteration: 495 Train loss: 0.166
Epoch: 2/10 Iteration: 500 Train loss: 0.157
Val acc: 0.578
Epoch: 2/10 Iteration: 505 Train loss: 0.358
Epoch: 2/10 Iteration: 510 Train loss: 0.241
Epoch: 2/10 Iteration: 515 Train loss: 0.268
Epoch: 2/10 Iteration: 520 Train loss: 0.255
Epoch: 2/10 Iteration: 525 Train loss: 0.231
Val acc: 0.547
Epoch: 2/10 Iteration: 530 Train loss: 0.239
Epoch: 2/10 Iteration: 535 Train loss: 0.224
Epoch: 2/10 Iteration: 540 Train loss: 0.230
Epoch: 2/10 Iteration: 545 Train loss: 0.224
Epoch: 2/10 Iteration: 550 Train loss: 0.232
Val acc: 0.600
Epoch: 2/10 Iteration: 555 Train loss: 0.227
Epoch: 2/10 Iteration: 560 Train loss: 0.225
Epoch: 2/10 Iteration: 565 Train loss: 0.197
Epoch: 2/10 Iteration: 570 Train loss: 0.199
Epoch: 2/10 Iteration: 575 Train loss: 0.165
Val acc: 0.698
Epoch: 2/10 Iteration: 580 Train loss: 0.216
Epoch: 2/10 Iteration: 585 Train loss: 0.223
Epoch: 2/10 Iteration: 590 Train loss: 0.268
Epoch: 2/10 Iteration: 595 Train loss: 0.212
Epoch: 2/10 Iteration: 600 Train loss: 0.248
Val acc: 0.559
Epoch: 3/10 Iteration: 605 Train loss: 0.229
Epoch: 3/10 Iteration: 610 Train loss: 0.203
Epoch: 3/10 Iteration: 615 Train loss: 0.215
Epoch: 3/10 Iteration: 620 Train loss: 0.189
Epoch: 3/10 Iteration: 625 Train loss: 0.204
Val acc: 0.608
Epoch: 3/10 Iteration: 630 Train loss: 0.229
Epoch: 3/10 Iteration: 635 Train loss: 0.163
Epoch: 3/10 Iteration: 640 Train loss: 0.178
Epoch: 3/10 Iteration: 645 Train loss: 0.167
Epoch: 3/10 Iteration: 650 Train loss: 0.173
Val acc: 0.760
Epoch: 3/10 Iteration: 655 Train loss: 0.154
Epoch: 3/10 Iteration: 660 Train loss: 0.146
Epoch: 3/10 Iteration: 665 Train loss: 0.123
Epoch: 3/10 Iteration: 670 Train loss: 0.111
Epoch: 3/10 Iteration: 675 Train loss: 0.110
Val acc: 0.782
Epoch: 3/10 Iteration: 680 Train loss: 0.095
Epoch: 3/10 Iteration: 685 Train loss: 0.077
Epoch: 3/10 Iteration: 690 Train loss: 0.036
Epoch: 3/10 Iteration: 695 Train loss: 0.039
Epoch: 3/10 Iteration: 700 Train loss: 0.025
Val acc: 0.808
Epoch: 3/10 Iteration: 705 Train loss: 0.020
Epoch: 3/10 Iteration: 710 Train loss: 0.011
Epoch: 3/10 Iteration: 715 Train loss: 0.057
Epoch: 3/10 Iteration: 720 Train loss: 0.043
Epoch: 3/10 Iteration: 725 Train loss: 0.020
Val acc: 0.684
Epoch: 3/10 Iteration: 730 Train loss: 0.023
Epoch: 3/10 Iteration: 735 Train loss: 0.029
Epoch: 3/10 Iteration: 740 Train loss: 0.028
Epoch: 3/10 Iteration: 745 Train loss: 0.030
Epoch: 3/10 Iteration: 750 Train loss: 0.057
Val acc: 0.550
Epoch: 3/10 Iteration: 755 Train loss: 0.048
Epoch: 3/10 Iteration: 760 Train loss: 0.048
Epoch: 3/10 Iteration: 765 Train loss: 0.009
Epoch: 3/10 Iteration: 770 Train loss: 0.017
Epoch: 3/10 Iteration: 775 Train loss: 0.017
Val acc: 0.660
Epoch: 3/10 Iteration: 780 Train loss: 0.010
Epoch: 3/10 Iteration: 785 Train loss: 0.016
Epoch: 3/10 Iteration: 790 Train loss: 0.020
Epoch: 3/10 Iteration: 795 Train loss: 0.010
Epoch: 3/10 Iteration: 800 Train loss: 0.010
Val acc: 0.653
Epoch: 4/10 Iteration: 805 Train loss: 0.326
Epoch: 4/10 Iteration: 810 Train loss: 0.265
Epoch: 4/10 Iteration: 815 Train loss: 0.144
Epoch: 4/10 Iteration: 820 Train loss: 0.188
Epoch: 4/10 Iteration: 825 Train loss: 0.166
Val acc: 0.700
Epoch: 4/10 Iteration: 830 Train loss: 0.151
Epoch: 4/10 Iteration: 835 Train loss: 0.152
Epoch: 4/10 Iteration: 840 Train loss: 0.149
Epoch: 4/10 Iteration: 845 Train loss: 0.105
Epoch: 4/10 Iteration: 850 Train loss: 0.087
Val acc: 0.839
Epoch: 4/10 Iteration: 855 Train loss: 0.064
Epoch: 4/10 Iteration: 860 Train loss: 0.062
Epoch: 4/10 Iteration: 865 Train loss: 0.013
Epoch: 4/10 Iteration: 870 Train loss: 0.007
Epoch: 4/10 Iteration: 875 Train loss: 0.003
Val acc: 0.767
Epoch: 4/10 Iteration: 880 Train loss: 0.010
Epoch: 4/10 Iteration: 885 Train loss: 0.001
Epoch: 4/10 Iteration: 890 Train loss: 0.010
Epoch: 4/10 Iteration: 895 Train loss: 0.010
Epoch: 4/10 Iteration: 900 Train loss: 0.001
Val acc: 0.689
Epoch: 4/10 Iteration: 905 Train loss: 0.001
Epoch: 4/10 Iteration: 910 Train loss: 0.010
Epoch: 4/10 Iteration: 915 Train loss: 0.010
Epoch: 4/10 Iteration: 920 Train loss: 0.002
Epoch: 4/10 Iteration: 925 Train loss: 0.000
Val acc: 0.928
Epoch: 4/10 Iteration: 930 Train loss: 0.037
Epoch: 4/10 Iteration: 935 Train loss: 0.077
Epoch: 4/10 Iteration: 940 Train loss: 0.001
Epoch: 4/10 Iteration: 945 Train loss: 0.010
Epoch: 4/10 Iteration: 950 Train loss: 0.026
Val acc: 0.754
Epoch: 4/10 Iteration: 955 Train loss: 0.001
Epoch: 4/10 Iteration: 960 Train loss: 0.002
Epoch: 4/10 Iteration: 965 Train loss: 0.001
Epoch: 4/10 Iteration: 970 Train loss: 0.002
Epoch: 4/10 Iteration: 975 Train loss: 0.001
Val acc: 0.751
Epoch: 4/10 Iteration: 980 Train loss: 0.001
Epoch: 4/10 Iteration: 985 Train loss: 0.001
Epoch: 4/10 Iteration: 990 Train loss: 0.001
Epoch: 4/10 Iteration: 995 Train loss: 0.000
Epoch: 4/10 Iteration: 1000 Train loss: 0.001
Val acc: 0.692
Epoch: 5/10 Iteration: 1005 Train loss: 0.125
Epoch: 5/10 Iteration: 1010 Train loss: 0.114
Epoch: 5/10 Iteration: 1015 Train loss: 0.104
Epoch: 5/10 Iteration: 1020 Train loss: 0.046
Epoch: 5/10 Iteration: 1025 Train loss: 0.006
Val acc: 0.932
Epoch: 5/10 Iteration: 1030 Train loss: 0.007
Epoch: 5/10 Iteration: 1035 Train loss: 0.003
Epoch: 5/10 Iteration: 1040 Train loss: 0.022
Epoch: 5/10 Iteration: 1045 Train loss: 0.004
Epoch: 5/10 Iteration: 1050 Train loss: 0.007
Val acc: 0.946
Epoch: 5/10 Iteration: 1055 Train loss: 0.001
Epoch: 5/10 Iteration: 1060 Train loss: 0.002
Epoch: 5/10 Iteration: 1065 Train loss: 0.001
Epoch: 5/10 Iteration: 1070 Train loss: 0.001
Epoch: 5/10 Iteration: 1075 Train loss: 0.001
Val acc: 0.964
Epoch: 5/10 Iteration: 1080 Train loss: 0.002
Epoch: 5/10 Iteration: 1085 Train loss: 0.000
Epoch: 5/10 Iteration: 1090 Train loss: 0.000
Epoch: 5/10 Iteration: 1095 Train loss: 0.001
Epoch: 5/10 Iteration: 1100 Train loss: 0.010
Val acc: 0.944
Epoch: 5/10 Iteration: 1105 Train loss: 0.000
Epoch: 5/10 Iteration: 1110 Train loss: 0.000
Epoch: 5/10 Iteration: 1115 Train loss: 0.007
Epoch: 5/10 Iteration: 1120 Train loss: 0.003
Epoch: 5/10 Iteration: 1125 Train loss: 0.000
Val acc: 0.937
Epoch: 5/10 Iteration: 1130 Train loss: 0.001
Epoch: 5/10 Iteration: 1135 Train loss: 0.000
Epoch: 5/10 Iteration: 1140 Train loss: 0.000
Epoch: 5/10 Iteration: 1145 Train loss: 0.010
Epoch: 5/10 Iteration: 1150 Train loss: 0.000
Val acc: 0.942
Epoch: 5/10 Iteration: 1155 Train loss: 0.000
Epoch: 5/10 Iteration: 1160 Train loss: 0.000
Epoch: 5/10 Iteration: 1165 Train loss: 0.000
Epoch: 5/10 Iteration: 1170 Train loss: 0.001
Epoch: 5/10 Iteration: 1175 Train loss: 0.000
Val acc: 0.958
Epoch: 5/10 Iteration: 1180 Train loss: 0.000
Epoch: 5/10 Iteration: 1185 Train loss: 0.001
Epoch: 5/10 Iteration: 1190 Train loss: 0.000
Epoch: 5/10 Iteration: 1195 Train loss: 0.000
Epoch: 5/10 Iteration: 1200 Train loss: 0.000
Val acc: 0.962
Epoch: 6/10 Iteration: 1205 Train loss: 0.022
Epoch: 6/10 Iteration: 1210 Train loss: 0.031
Epoch: 6/10 Iteration: 1215 Train loss: 0.071
Epoch: 6/10 Iteration: 1220 Train loss: 0.016
Epoch: 6/10 Iteration: 1225 Train loss: 0.002
Val acc: 0.953
Epoch: 6/10 Iteration: 1230 Train loss: 0.007
Epoch: 6/10 Iteration: 1235 Train loss: 0.003
Epoch: 6/10 Iteration: 1240 Train loss: 0.010
Epoch: 6/10 Iteration: 1245 Train loss: 0.000
Epoch: 6/10 Iteration: 1250 Train loss: 0.004
Val acc: 0.952
Epoch: 6/10 Iteration: 1255 Train loss: 0.000
Epoch: 6/10 Iteration: 1260 Train loss: 0.001
Epoch: 6/10 Iteration: 1265 Train loss: 0.000
Epoch: 6/10 Iteration: 1270 Train loss: 0.000
Epoch: 6/10 Iteration: 1275 Train loss: 0.000
Val acc: 0.949
Epoch: 6/10 Iteration: 1280 Train loss: 0.000
Epoch: 6/10 Iteration: 1285 Train loss: 0.000
Epoch: 6/10 Iteration: 1290 Train loss: 0.000
Epoch: 6/10 Iteration: 1295 Train loss: 0.002
Epoch: 6/10 Iteration: 1300 Train loss: 0.001
Val acc: 0.955
Epoch: 6/10 Iteration: 1305 Train loss: 0.000
Epoch: 6/10 Iteration: 1310 Train loss: 0.000
Epoch: 6/10 Iteration: 1315 Train loss: 0.000
Epoch: 6/10 Iteration: 1320 Train loss: 0.000
Epoch: 6/10 Iteration: 1325 Train loss: 0.000
Val acc: 0.958
Epoch: 6/10 Iteration: 1330 Train loss: 0.000
Epoch: 6/10 Iteration: 1335 Train loss: 0.000
Epoch: 6/10 Iteration: 1340 Train loss: 0.000
Epoch: 6/10 Iteration: 1345 Train loss: 0.000
Epoch: 6/10 Iteration: 1350 Train loss: 0.000
Val acc: 0.946
Epoch: 6/10 Iteration: 1355 Train loss: 0.000
Epoch: 6/10 Iteration: 1360 Train loss: 0.000
Epoch: 6/10 Iteration: 1365 Train loss: 0.000
Epoch: 6/10 Iteration: 1370 Train loss: 0.000
Epoch: 6/10 Iteration: 1375 Train loss: 0.000
Val acc: 0.944
Epoch: 6/10 Iteration: 1380 Train loss: 0.000
Epoch: 6/10 Iteration: 1385 Train loss: 0.000
Epoch: 6/10 Iteration: 1390 Train loss: 0.000
Epoch: 6/10 Iteration: 1395 Train loss: 0.000
Epoch: 6/10 Iteration: 1400 Train loss: 0.000
Val acc: 0.943
Epoch: 7/10 Iteration: 1405 Train loss: 0.010
Epoch: 7/10 Iteration: 1410 Train loss: 0.016
Epoch: 7/10 Iteration: 1415 Train loss: 0.017
Epoch: 7/10 Iteration: 1420 Train loss: 0.010
Epoch: 7/10 Iteration: 1425 Train loss: 0.001
Val acc: 0.894
Epoch: 7/10 Iteration: 1430 Train loss: 0.000
Epoch: 7/10 Iteration: 1435 Train loss: 0.001
Epoch: 7/10 Iteration: 1440 Train loss: 0.011
Epoch: 7/10 Iteration: 1445 Train loss: 0.000
Epoch: 7/10 Iteration: 1450 Train loss: 0.017
Val acc: 0.944
Epoch: 7/10 Iteration: 1455 Train loss: 0.001
Epoch: 7/10 Iteration: 1460 Train loss: 0.001
Epoch: 7/10 Iteration: 1465 Train loss: 0.000
Epoch: 7/10 Iteration: 1470 Train loss: 0.000
Epoch: 7/10 Iteration: 1475 Train loss: 0.000
Val acc: 0.962
Epoch: 7/10 Iteration: 1480 Train loss: 0.001
Epoch: 7/10 Iteration: 1485 Train loss: 0.001
Epoch: 7/10 Iteration: 1490 Train loss: 0.000
Epoch: 7/10 Iteration: 1495 Train loss: 0.000
Epoch: 7/10 Iteration: 1500 Train loss: 0.000
Val acc: 0.955
Epoch: 7/10 Iteration: 1505 Train loss: 0.000
Epoch: 7/10 Iteration: 1510 Train loss: 0.000
Epoch: 7/10 Iteration: 1515 Train loss: 0.000
Epoch: 7/10 Iteration: 1520 Train loss: 0.000
Epoch: 7/10 Iteration: 1525 Train loss: 0.000
Val acc: 0.945
Epoch: 7/10 Iteration: 1530 Train loss: 0.000
Epoch: 7/10 Iteration: 1535 Train loss: 0.000
Epoch: 7/10 Iteration: 1540 Train loss: 0.000
Epoch: 7/10 Iteration: 1545 Train loss: 0.007
Epoch: 7/10 Iteration: 1550 Train loss: 0.002
Val acc: 0.937
Epoch: 7/10 Iteration: 1555 Train loss: 0.000
Epoch: 7/10 Iteration: 1560 Train loss: 0.000
Epoch: 7/10 Iteration: 1565 Train loss: 0.002
Epoch: 7/10 Iteration: 1570 Train loss: 0.003
Epoch: 7/10 Iteration: 1575 Train loss: 0.000
Val acc: 0.961
Epoch: 7/10 Iteration: 1580 Train loss: 0.008
Epoch: 7/10 Iteration: 1585 Train loss: 0.001
Epoch: 7/10 Iteration: 1590 Train loss: 0.000
Epoch: 7/10 Iteration: 1595 Train loss: 0.000
Epoch: 7/10 Iteration: 1600 Train loss: 0.001
Val acc: 0.961
Epoch: 8/10 Iteration: 1605 Train loss: 0.008
Epoch: 8/10 Iteration: 1610 Train loss: 0.000
Epoch: 8/10 Iteration: 1615 Train loss: 0.006
Epoch: 8/10 Iteration: 1620 Train loss: 0.000
Epoch: 8/10 Iteration: 1625 Train loss: 0.000
Val acc: 0.910
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-24-da10d9fe9550> in <module>()
     15                     keep_prob: 0.5,
     16                     initial_state: state}
---> 17             loss, state, _ = sess.run([cost, final_state, optimizer], feed_dict=feed)
     18 
     19             if iteration%5==0:

[... TensorFlow session internals elided ...]

KeyboardInterrupt: 

Testing


In [ ]:
test_acc = []
with tf.Session(graph=graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    test_state = sess.run(cell.zero_state(batch_size, tf.float32))
    for ii, (x, y) in enumerate(get_batches(test_x, test_y, batch_size), 1):
        feed = {inputs_: x,
                labels_: y[:, None],
                keep_prob: 1,
                initial_state: test_state}
        batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed)
        test_acc.append(batch_acc)
    print("Test accuracy: {:.3f}".format(np.mean(test_acc)))