In [1]:
import tensorflow as tf
import numpy as np
import re
import collections
import sklearn.metrics as sk
from helper_functions_wsj import *
from glob import glob
from reader import Reader
import time

%load_ext autoreload
%autoreload 2

In [2]:
print('Loading WSJ Data')
reader = Reader(split=0.9)
(X_train, Y_train, mask_train,
 X_test, Y_test, mask_test) = \
    reader.get_data(glob('./data/WSJ/*/*.POS'))
print('Loaded WSJ Data')


Loading WSJ Data
extended [('**start**', 'START'), ('It', 'PRP'), ('has', 'VBZ'), ('no', 'DT'), ('bearing', 'NN'), ('on', 'IN'), ('our', 'PRP$'), ('work', 'NN'), ('force', 'NN'), ('today', 'NN'), ('.', '.'), ("''", "''"), ('**end**', 'END')]
extended [('**start**', 'START'), ('He', 'PRP'), ('predicted', 'VBD'), ('the', 'DT'), ('problem', 'NN'), ('will', 'MD'), ('be', 'VB'), ('solved', 'VBN'), ('``', '``'), ('very', 'RB'), ('soon', 'RB'), ('.', '.'), ("''", "''"), ('**end**', 'END')]
pad all sentences to 36
Loaded WSJ Data

In [3]:
graph = tf.Graph()
with graph.as_default():
    batch_size = 32
    hidden_size = 128
    num_layers = 3
    vocab_size = len(reader.word_to_id)
    tag_size = len(reader.tag_to_id)
    maxlen = reader.maxlen

    input_data = tf.placeholder(tf.int64, [None, maxlen])  # word ids, padded to maxlen
    targets = tf.placeholder(tf.int64, [None, maxlen])      # POS tag ids, padded to maxlen
    mask = tf.placeholder(tf.bool, [None, maxlen])          # True where a position counts toward loss/accuracy

    lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_size, state_is_tuple=True)
    # if is_training and dropout_keep_prob < 1:
    #     lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
    #         lstm_cell, output_keep_prob=dropout_keep_prob)

    cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_layers, state_is_tuple=True)
    cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_layers, state_is_tuple=True)
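    # note (added for clarity): reusing a single LSTMCell object is the idiom this pre-1.0 TF API
    # expects; variable scoping inside MultiRNNCell and bidirectional_rnn still gives each layer
    # and each direction its own weights.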

    initial_state_fw = cell_fw.zero_state(tf.shape(input_data)[0], tf.float32)
    initial_state_bw = cell_bw.zero_state(tf.shape(input_data)[0], tf.float32)

    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding", [vocab_size,
                                                  hidden_size])
        inputs = tf.nn.embedding_lookup(embedding, input_data)

    # unpack into a length-maxlen list of [batch, hidden_size] tensors (time-major),
    # as tf.nn.bidirectional_rnn expects
    inputs = tf.unpack(tf.transpose(inputs, [1, 0, 2]))
    # if is_training and dropout_keep_prob < 1:
    #     inputs = tf.nn.dropout(tf.pack(inputs), dropout_keep_prob)
    #     inputs = tf.unpack(inputs)
    outputs, _, _ = tf.nn.bidirectional_rnn(cell_fw, cell_bw, inputs,
                                            initial_state_fw=initial_state_fw,
                                            initial_state_bw=initial_state_bw)

    # each element of outputs is [batch, 2*hidden_size] (forward and backward concatenated);
    # flatten batch and time into rows of size 2*hidden_size for the softmax layer.
    output = tf.reshape(tf.concat(1, outputs), [-1, 2 * hidden_size])
    softmax_w = tf.get_variable("softmax_w", [2 * hidden_size, tag_size])
    softmax_b = tf.get_variable("softmax_b", [tag_size])
    logits = tf.matmul(output, softmax_w) + softmax_b
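    # sequence_loss_by_example gives per-position cross-entropy; the float-cast mask zeroes out
    # positions that should not be scored, and cost sums the result and divides by batch_size.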
    loss = tf.nn.seq2seq.sequence_loss_by_example(
        [logits], [tf.reshape(targets, [-1])],
        [tf.reshape(tf.cast(mask, tf.float32), [-1])], tag_size)
    cost = tf.reduce_sum(loss) / batch_size

    equality = tf.equal(tf.argmax(logits, 1),
                        tf.cast(tf.reshape(targets, [-1]), tf.int64))
    masked = tf.boolean_mask(equality, tf.reshape(mask, [-1]))
    misclass = 1 - tf.reduce_mean(tf.cast(masked, tf.float32))

    lr = tf.Variable(0.0, trainable=False)  # learning rate; reset each epoch via assign_lr below
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5.0)  # clip global grad norm at 5.0
    optimizer = tf.train.GradientDescentOptimizer(lr)
    train_op = optimizer.apply_gradients(zip(grads, tvars))

def assign_lr(session, lr_value):
    session.run(tf.assign(lr, lr_value))
    
def run_epoch(x_data, y_data, data_mask, eval_op, training=True, verbose=False):
    """Runs the model on the given data."""
    epoch_size = ((len(x_data) // batch_size) - 1)
    start_time = time.time()
    costs = 0.0
    iters = 0
    misclass_ = []
    for step, (x, y, batch_mask) in enumerate(Reader.iterator(x_data, y_data, data_mask, batch_size)):
        if training:
            l, misclassifications, _ = sess.run([cost, misclass, eval_op],
                                                {input_data: x, targets: y, mask: batch_mask})
        else:
            l, misclassifications = sess.run([cost, misclass],
                                             {input_data: x, targets: y, mask: batch_mask})
        costs += l
        iters += batch_size

        if verbose and step % (epoch_size // 10) == 0:
            print("[%s] %.3f perplexity: %.3f misclass:%.3f speed: %.0f wps" %
                  ('train' if training else 'test', step * 1.0 / epoch_size,
                   np.exp(costs / iters), misclassifications,
                   iters * batch_size / (time.time() - start_time)))
        misclass_.append(misclassifications)
    return np.exp(costs / iters), np.mean(misclass_)

In [4]:
sess = tf.InteractiveSession(graph=graph)
tf.initialize_all_variables().run()
print('Initialized')

# create saver for model
saver = tf.train.Saver(max_to_keep=1)


Initialized

In [17]:
sess.close()

In [5]:
best_misclass = 1.0

for i in range(10):
    lr_decay = 0.5 ** max(i - 4, 0.0)
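    # lr stays at 1.0 for the first five epochs (i = 0..4), then halves each epoch: 0.5, 0.25, ...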
    assign_lr(sess, 1.0 * lr_decay)

    print("Epoch: %d Learning rate: %.3f" % (i + 1, sess.run(lr)))
    train_perplexity, _ = run_epoch(X_train, Y_train, mask_train,
                                    train_op, verbose=True)
    _, misclassifications = run_epoch(X_test, Y_test, mask_test,
                            tf.no_op(), training=False, verbose=True)
    if misclassifications < best_misclass:
        best_misclass = misclassifications
        saver.save(sess, './data/bid3rnn_tagger.ckpt', global_step=i)
        print('Saving')


Epoch: 1 Learning rate: 1.000
[train] 0.000 perplexity: 7.755 misclass:0.998 speed: 930 wps
[train] 0.100 perplexity: 5.157 misclass:0.822 speed: 7597 wps
[train] 0.200 perplexity: 4.398 misclass:0.717 speed: 7752 wps
[train] 0.299 perplexity: 3.409 misclass:0.341 speed: 7802 wps
[train] 0.399 perplexity: 2.744 misclass:0.169 speed: 7831 wps
[train] 0.499 perplexity: 2.332 misclass:0.090 speed: 7828 wps
[train] 0.599 perplexity: 2.079 misclass:0.065 speed: 7838 wps
[train] 0.699 perplexity: 1.906 misclass:0.049 speed: 7848 wps
[train] 0.799 perplexity: 1.783 misclass:0.093 speed: 7850 wps
[train] 0.898 perplexity: 1.691 misclass:0.041 speed: 7853 wps
[train] 0.998 perplexity: 1.619 misclass:0.037 speed: 7853 wps
[test] 0.000 perplexity: 1.126 misclass:0.067 speed: 2055 wps
[test] 0.098 perplexity: 1.101 misclass:0.067 speed: 17384 wps
[test] 0.196 perplexity: 1.099 misclass:0.055 speed: 21542 wps
[test] 0.294 perplexity: 1.100 misclass:0.056 speed: 23489 wps
[test] 0.392 perplexity: 1.099 misclass:0.043 speed: 24536 wps
[test] 0.490 perplexity: 1.096 misclass:0.053 speed: 25233 wps
[test] 0.588 perplexity: 1.097 misclass:0.050 speed: 25679 wps
[test] 0.686 perplexity: 1.097 misclass:0.064 speed: 26012 wps
[test] 0.784 perplexity: 1.098 misclass:0.062 speed: 26325 wps
[test] 0.881 perplexity: 1.097 misclass:0.069 speed: 26546 wps
[test] 0.979 perplexity: 1.098 misclass:0.053 speed: 26734 wps
Saving
Epoch: 2 Learning rate: 1.000
[train] 0.000 perplexity: 1.060 misclass:0.057 speed: 7918 wps
[train] 0.100 perplexity: 1.094 misclass:0.047 speed: 7889 wps
[train] 0.200 perplexity: 1.090 misclass:0.057 speed: 7888 wps
[train] 0.299 perplexity: 1.088 misclass:0.035 speed: 7884 wps
[train] 0.399 perplexity: 1.085 misclass:0.055 speed: 7880 wps
[train] 0.499 perplexity: 1.083 misclass:0.034 speed: 7879 wps
[train] 0.599 perplexity: 1.081 misclass:0.040 speed: 7875 wps
[train] 0.699 perplexity: 1.080 misclass:0.039 speed: 7869 wps
[train] 0.799 perplexity: 1.078 misclass:0.062 speed: 7858 wps
[train] 0.898 perplexity: 1.077 misclass:0.037 speed: 7854 wps
[train] 0.998 perplexity: 1.076 misclass:0.035 speed: 7852 wps
[test] 0.000 perplexity: 1.089 misclass:0.053 speed: 7320 wps
[test] 0.098 perplexity: 1.078 misclass:0.045 speed: 25109 wps
[test] 0.196 perplexity: 1.076 misclass:0.046 speed: 26668 wps
[test] 0.294 perplexity: 1.077 misclass:0.052 speed: 27221 wps
[test] 0.392 perplexity: 1.076 misclass:0.029 speed: 27600 wps
[test] 0.490 perplexity: 1.073 misclass:0.035 speed: 27809 wps
[test] 0.588 perplexity: 1.074 misclass:0.038 speed: 27953 wps
[test] 0.686 perplexity: 1.073 misclass:0.056 speed: 28077 wps
[test] 0.784 perplexity: 1.074 misclass:0.046 speed: 28101 wps
[test] 0.881 perplexity: 1.074 misclass:0.056 speed: 28107 wps
[test] 0.979 perplexity: 1.074 misclass:0.051 speed: 28147 wps
Saving
Epoch: 3 Learning rate: 1.000
[train] 0.000 perplexity: 1.047 misclass:0.036 speed: 7094 wps
[train] 0.100 perplexity: 1.065 misclass:0.025 speed: 7871 wps
[train] 0.200 perplexity: 1.063 misclass:0.024 speed: 7858 wps
[train] 0.299 perplexity: 1.063 misclass:0.033 speed: 7851 wps
[train] 0.399 perplexity: 1.062 misclass:0.049 speed: 7849 wps
[train] 0.499 perplexity: 1.061 misclass:0.030 speed: 7846 wps
[train] 0.599 perplexity: 1.061 misclass:0.047 speed: 7848 wps
[train] 0.699 perplexity: 1.060 misclass:0.033 speed: 7845 wps
[train] 0.799 perplexity: 1.060 misclass:0.052 speed: 7845 wps
[train] 0.898 perplexity: 1.059 misclass:0.024 speed: 7841 wps
[train] 0.998 perplexity: 1.059 misclass:0.033 speed: 7837 wps
[test] 0.000 perplexity: 1.078 misclass:0.038 speed: 7592 wps
[test] 0.098 perplexity: 1.071 misclass:0.045 speed: 25310 wps
[test] 0.196 perplexity: 1.071 misclass:0.044 speed: 26644 wps
[test] 0.294 perplexity: 1.072 misclass:0.048 speed: 27089 wps
[test] 0.392 perplexity: 1.070 misclass:0.035 speed: 27415 wps
[test] 0.490 perplexity: 1.068 misclass:0.028 speed: 27643 wps
[test] 0.588 perplexity: 1.068 misclass:0.031 speed: 27744 wps
[test] 0.686 perplexity: 1.067 misclass:0.053 speed: 27826 wps
[test] 0.784 perplexity: 1.068 misclass:0.041 speed: 27913 wps
[test] 0.881 perplexity: 1.068 misclass:0.047 speed: 27974 wps
[test] 0.979 perplexity: 1.068 misclass:0.042 speed: 28011 wps
Saving
Epoch: 4 Learning rate: 1.000
[train] 0.000 perplexity: 1.041 misclass:0.023 speed: 7688 wps
[train] 0.100 perplexity: 1.055 misclass:0.025 speed: 7863 wps
[train] 0.200 perplexity: 1.053 misclass:0.026 speed: 7749 wps
[train] 0.299 perplexity: 1.053 misclass:0.031 speed: 7035 wps
[train] 0.399 perplexity: 1.053 misclass:0.053 speed: 6935 wps
[train] 0.499 perplexity: 1.052 misclass:0.032 speed: 6963 wps
[train] 0.599 perplexity: 1.052 misclass:0.038 speed: 6997 wps
[train] 0.699 perplexity: 1.052 misclass:0.029 speed: 7082 wps
[train] 0.799 perplexity: 1.052 misclass:0.050 speed: 6929 wps
[train] 0.898 perplexity: 1.052 misclass:0.020 speed: 6963 wps
[train] 0.998 perplexity: 1.051 misclass:0.031 speed: 6989 wps
[test] 0.000 perplexity: 1.075 misclass:0.042 speed: 7413 wps
[test] 0.098 perplexity: 1.070 misclass:0.045 speed: 25238 wps
[test] 0.196 perplexity: 1.070 misclass:0.042 speed: 26751 wps
[test] 0.294 perplexity: 1.071 misclass:0.044 speed: 27309 wps
[test] 0.392 perplexity: 1.069 misclass:0.027 speed: 27702 wps
[test] 0.490 perplexity: 1.067 misclass:0.024 speed: 27709 wps
[test] 0.588 perplexity: 1.067 misclass:0.036 speed: 27371 wps
[test] 0.686 perplexity: 1.066 misclass:0.053 speed: 27515 wps
[test] 0.784 perplexity: 1.067 misclass:0.041 speed: 27632 wps
[test] 0.881 perplexity: 1.067 misclass:0.047 speed: 27710 wps
[test] 0.979 perplexity: 1.067 misclass:0.044 speed: 27824 wps
Saving
Epoch: 5 Learning rate: 1.000
[train] 0.000 perplexity: 1.036 misclass:0.023 speed: 5721 wps
[train] 0.100 perplexity: 1.049 misclass:0.025 speed: 6708 wps
[train] 0.200 perplexity: 1.048 misclass:0.026 speed: 7108 wps
[train] 0.299 perplexity: 1.048 misclass:0.029 speed: 7121 wps
[train] 0.399 perplexity: 1.048 misclass:0.045 speed: 7166 wps
[train] 0.499 perplexity: 1.048 misclass:0.026 speed: 7208 wps
[train] 0.599 perplexity: 1.047 misclass:0.031 speed: 7233 wps
[train] 0.699 perplexity: 1.047 misclass:0.029 speed: 7118 wps
[train] 0.799 perplexity: 1.047 misclass:0.052 speed: 6437 wps
[train] 0.898 perplexity: 1.047 misclass:0.020 speed: 6005 wps
[train] 0.998 perplexity: 1.047 misclass:0.029 speed: 5703 wps
[test] 0.000 perplexity: 1.070 misclass:0.042 speed: 5855 wps
[test] 0.098 perplexity: 1.068 misclass:0.049 speed: 11833 wps
[test] 0.196 perplexity: 1.068 misclass:0.044 speed: 11683 wps
[test] 0.294 perplexity: 1.069 misclass:0.036 speed: 12239 wps
[test] 0.392 perplexity: 1.067 misclass:0.027 speed: 12500 wps
[test] 0.490 perplexity: 1.065 misclass:0.026 speed: 12487 wps
[test] 0.588 perplexity: 1.065 misclass:0.033 speed: 12523 wps
[test] 0.686 perplexity: 1.065 misclass:0.058 speed: 12470 wps
[test] 0.784 perplexity: 1.066 misclass:0.046 speed: 12415 wps
[test] 0.881 perplexity: 1.066 misclass:0.049 speed: 12458 wps
[test] 0.979 perplexity: 1.066 misclass:0.042 speed: 12576 wps
Saving
Epoch: 6 Learning rate: 0.500
[train] 0.000 perplexity: 1.032 misclass:0.023 speed: 4471 wps
[train] 0.100 perplexity: 1.043 misclass:0.027 speed: 4084 wps
[train] 0.200 perplexity: 1.042 misclass:0.020 speed: 4144 wps
[train] 0.299 perplexity: 1.042 misclass:0.020 speed: 4087 wps
[train] 0.399 perplexity: 1.041 misclass:0.043 speed: 4053 wps
[train] 0.499 perplexity: 1.041 misclass:0.024 speed: 4134 wps
[train] 0.599 perplexity: 1.040 misclass:0.031 speed: 4171 wps
[train] 0.699 perplexity: 1.040 misclass:0.028 speed: 4186 wps
[train] 0.799 perplexity: 1.039 misclass:0.048 speed: 4175 wps
[train] 0.898 perplexity: 1.039 misclass:0.022 speed: 4179 wps
[train] 0.998 perplexity: 1.039 misclass:0.025 speed: 4208 wps
[test] 0.000 perplexity: 1.073 misclass:0.046 speed: 5895 wps
[test] 0.098 perplexity: 1.068 misclass:0.049 speed: 14557 wps
[test] 0.196 perplexity: 1.067 misclass:0.036 speed: 15220 wps
[test] 0.294 perplexity: 1.068 misclass:0.040 speed: 15434 wps
[test] 0.392 perplexity: 1.065 misclass:0.027 speed: 15548 wps
[test] 0.490 perplexity: 1.063 misclass:0.022 speed: 15613 wps
[test] 0.588 perplexity: 1.063 misclass:0.038 speed: 15626 wps
[test] 0.686 perplexity: 1.063 misclass:0.051 speed: 15677 wps
[test] 0.784 perplexity: 1.064 misclass:0.039 speed: 15653 wps
[test] 0.881 perplexity: 1.063 misclass:0.039 speed: 15699 wps
[test] 0.979 perplexity: 1.064 misclass:0.036 speed: 15599 wps
Saving
Epoch: 7 Learning rate: 0.250
[train] 0.000 perplexity: 1.026 misclass:0.019 speed: 4576 wps
[train] 0.100 perplexity: 1.037 misclass:0.025 speed: 4010 wps
[train] 0.200 perplexity: 1.036 misclass:0.018 speed: 4214 wps
[train] 0.299 perplexity: 1.036 misclass:0.018 speed: 4217 wps
[train] 0.399 perplexity: 1.036 misclass:0.041 speed: 4248 wps
[train] 0.499 perplexity: 1.035 misclass:0.028 speed: 4278 wps
[train] 0.599 perplexity: 1.035 misclass:0.029 speed: 4308 wps
[train] 0.699 perplexity: 1.035 misclass:0.026 speed: 4290 wps
[train] 0.799 perplexity: 1.034 misclass:0.044 speed: 4487 wps
[train] 0.898 perplexity: 1.034 misclass:0.015 speed: 4709 wps
[train] 0.998 perplexity: 1.033 misclass:0.023 speed: 4895 wps
[test] 0.000 perplexity: 1.072 misclass:0.049 speed: 7308 wps
[test] 0.098 perplexity: 1.069 misclass:0.047 speed: 24936 wps
[test] 0.196 perplexity: 1.068 misclass:0.029 speed: 26516 wps
[test] 0.294 perplexity: 1.068 misclass:0.040 speed: 26936 wps
[test] 0.392 perplexity: 1.066 misclass:0.027 speed: 27119 wps
[test] 0.490 perplexity: 1.063 misclass:0.024 speed: 27244 wps
[test] 0.588 perplexity: 1.064 misclass:0.039 speed: 27274 wps
[test] 0.686 perplexity: 1.063 misclass:0.053 speed: 27271 wps
[test] 0.784 perplexity: 1.064 misclass:0.039 speed: 27309 wps
[test] 0.881 perplexity: 1.064 misclass:0.032 speed: 27425 wps
[test] 0.979 perplexity: 1.064 misclass:0.038 speed: 27484 wps
Saving
Epoch: 8 Learning rate: 0.125
[train] 0.000 perplexity: 1.023 misclass:0.019 speed: 4884 wps
[train] 0.100 perplexity: 1.034 misclass:0.019 speed: 7299 wps
[train] 0.200 perplexity: 1.033 misclass:0.016 speed: 7307 wps
[train] 0.299 perplexity: 1.033 misclass:0.018 speed: 7408 wps
[train] 0.399 perplexity: 1.032 misclass:0.041 speed: 7439 wps
[train] 0.499 perplexity: 1.032 misclass:0.024 speed: 7430 wps
[train] 0.599 perplexity: 1.032 misclass:0.025 speed: 7383 wps
[train] 0.699 perplexity: 1.031 misclass:0.025 speed: 7384 wps
[train] 0.799 perplexity: 1.031 misclass:0.035 speed: 7259 wps
[train] 0.898 perplexity: 1.031 misclass:0.015 speed: 7270 wps
[train] 0.998 perplexity: 1.030 misclass:0.023 speed: 7288 wps
[test] 0.000 perplexity: 1.073 misclass:0.051 speed: 7286 wps
[test] 0.098 perplexity: 1.070 misclass:0.041 speed: 24001 wps
[test] 0.196 perplexity: 1.069 misclass:0.032 speed: 25801 wps
[test] 0.294 perplexity: 1.069 misclass:0.042 speed: 26442 wps
[test] 0.392 perplexity: 1.066 misclass:0.027 speed: 26589 wps
[test] 0.490 perplexity: 1.064 misclass:0.022 speed: 26950 wps
[test] 0.588 perplexity: 1.064 misclass:0.038 speed: 27208 wps
[test] 0.686 perplexity: 1.064 misclass:0.053 speed: 27341 wps
[test] 0.784 perplexity: 1.065 misclass:0.034 speed: 27302 wps
[test] 0.881 perplexity: 1.064 misclass:0.032 speed: 27040 wps
[test] 0.979 perplexity: 1.064 misclass:0.036 speed: 26427 wps
Saving
Epoch: 9 Learning rate: 0.062
[train] 0.000 perplexity: 1.022 misclass:0.017 speed: 7450 wps
[train] 0.100 perplexity: 1.031 misclass:0.017 speed: 7661 wps
[train] 0.200 perplexity: 1.031 misclass:0.018 speed: 7665 wps
[train] 0.299 perplexity: 1.031 misclass:0.015 speed: 7650 wps
[train] 0.399 perplexity: 1.030 misclass:0.039 speed: 7548 wps
[train] 0.499 perplexity: 1.030 misclass:0.023 speed: 7488 wps
[train] 0.599 perplexity: 1.030 misclass:0.027 speed: 7430 wps
[train] 0.699 perplexity: 1.030 misclass:0.023 speed: 7470 wps
[train] 0.799 perplexity: 1.029 misclass:0.035 speed: 7366 wps
[train] 0.898 perplexity: 1.029 misclass:0.013 speed: 7211 wps
[train] 0.998 perplexity: 1.028 misclass:0.018 speed: 7206 wps
[test] 0.000 perplexity: 1.073 misclass:0.051 speed: 7315 wps
[test] 0.098 perplexity: 1.070 misclass:0.045 speed: 24555 wps
[test] 0.196 perplexity: 1.069 misclass:0.034 speed: 26325 wps
[test] 0.294 perplexity: 1.069 misclass:0.042 speed: 26896 wps
[test] 0.392 perplexity: 1.067 misclass:0.025 speed: 27075 wps
[test] 0.490 perplexity: 1.064 misclass:0.024 speed: 27003 wps
[test] 0.588 perplexity: 1.064 misclass:0.036 speed: 27165 wps
[test] 0.686 perplexity: 1.064 misclass:0.053 speed: 27247 wps
[test] 0.784 perplexity: 1.065 misclass:0.037 speed: 27238 wps
[test] 0.881 perplexity: 1.064 misclass:0.030 speed: 27377 wps
[test] 0.979 perplexity: 1.065 misclass:0.036 speed: 27285 wps
Saving
Epoch: 10 Learning rate: 0.031
[train] 0.000 perplexity: 1.020 misclass:0.019 speed: 7753 wps
[train] 0.100 perplexity: 1.030 misclass:0.016 speed: 7201 wps
[train] 0.200 perplexity: 1.030 misclass:0.016 speed: 7321 wps
[train] 0.299 perplexity: 1.030 misclass:0.018 speed: 7260 wps
[train] 0.399 perplexity: 1.029 misclass:0.037 speed: 7319 wps
[train] 0.499 perplexity: 1.029 misclass:0.021 speed: 7304 wps
[train] 0.599 perplexity: 1.029 misclass:0.027 speed: 7278 wps
[train] 0.699 perplexity: 1.029 misclass:0.023 speed: 7305 wps
[train] 0.799 perplexity: 1.028 misclass:0.031 speed: 7325 wps
[train] 0.898 perplexity: 1.028 misclass:0.013 speed: 7351 wps
[train] 0.998 perplexity: 1.027 misclass:0.018 speed: 7350 wps
[test] 0.000 perplexity: 1.073 misclass:0.051 speed: 7636 wps
[test] 0.098 perplexity: 1.070 misclass:0.045 speed: 24717 wps
[test] 0.196 perplexity: 1.069 misclass:0.036 speed: 26122 wps
[test] 0.294 perplexity: 1.069 misclass:0.040 speed: 26648 wps
[test] 0.392 perplexity: 1.067 misclass:0.027 speed: 26311 wps
[test] 0.490 perplexity: 1.064 misclass:0.022 speed: 25776 wps
[test] 0.588 perplexity: 1.064 misclass:0.038 speed: 25694 wps
[test] 0.686 perplexity: 1.064 misclass:0.051 speed: 25628 wps
[test] 0.784 perplexity: 1.065 misclass:0.039 speed: 25791 wps
[test] 0.881 perplexity: 1.065 misclass:0.030 speed: 26011 wps
[test] 0.979 perplexity: 1.065 misclass:0.036 speed: 26214 wps
Saving

In [6]:
saver.restore(sess, "./data/bid3rnn_tagger.ckpt-9")
print("Best model restored!")


Best model restored!
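
A quick sanity check one can run at this point (a sketch added for illustration, not part of the original notebook): tag the first WSJ test sentence with the restored model and print the predicted tags for its masked positions.

example_ids, example_mask = X_test[:1], mask_test[0].astype(bool)
preds = sess.run(tf.argmax(logits, 1), {input_data: example_ids})  # [maxlen] predicted tag ids
id_to_tag = {v: k for k, v in reader.tag_to_id.items()}            # invert the tag vocabulary
print([id_to_tag[t] for t in preds[example_mask]])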

In [7]:
smothered_logits = tf.boolean_mask(logits, tf.reshape(mask, [-1]))
smothered_targets = tf.reshape(tf.boolean_mask(targets, mask), [-1])

s = tf.nn.softmax(smothered_logits)
s_prob = tf.reduce_max(s, reduction_indices=[1], keep_dims=True)
kl_all = tf.log(len(reader.tag_to_id)*1.) + tf.reduce_sum(s * tf.log(tf.abs(s) + 1e-10),
                                                          reduction_indices=[1], keep_dims=True)
m_all, v_all = tf.nn.moments(kl_all, axes=[0])

logits_right = tf.boolean_mask(smothered_logits,
                               tf.equal(tf.argmax(smothered_logits, 1), smothered_targets))
s_right = tf.nn.softmax(logits_right)
s_right_prob = tf.reduce_max(s_right, reduction_indices=[1], keep_dims=True)
kl_right = tf.log(len(reader.tag_to_id)*1.) + tf.reduce_sum(s_right * tf.log(tf.abs(s_right) + 1e-10),
                                                            reduction_indices=[1], keep_dims=True)
m_right, v_right = tf.nn.moments(kl_right, axes=[0])

logits_wrong = tf.boolean_mask(smothered_logits,
                               tf.not_equal(tf.argmax(smothered_logits, 1), smothered_targets))
s_wrong = tf.nn.softmax(logits_wrong)
s_wrong_prob = tf.reduce_max(s_wrong, reduction_indices=[1], keep_dims=True)
kl_wrong = tf.log(len(reader.tag_to_id)*1.) + tf.reduce_sum(s_wrong * tf.log(tf.abs(s_wrong) + 1e-10),
                                                            reduction_indices=[1], keep_dims=True)
m_wrong, v_wrong = tf.nn.moments(kl_wrong, axes=[0])
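
A small NumPy check (illustrative only, not from the notebook) of the KL[p||u] score built above: KL[p||u] = log(K) + sum_i p_i log p_i, i.e. log(K) minus the entropy of the softmax, so it is near 0 for a uniform prediction and approaches log(K) as the prediction becomes one-hot.

import numpy as np

def kl_from_uniform(p, K=5):
    # KL(p || uniform) = log K + sum_i p_i log p_i
    return np.log(K) + np.sum(p * np.log(p + 1e-10))

print(kl_from_uniform(np.full(5, 0.2)))                           # ~0.0 for a uniform prediction
print(kl_from_uniform(np.array([0.96, 0.01, 0.01, 0.01, 0.01])))  # ~1.39, approaching log(5) ~ 1.61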

In [8]:
err, kl_a, kl_r, kl_w, s_p, s_rp, s_wp = sess.run(
    [100*misclass, kl_all, kl_right, kl_wrong, s_prob, s_right_prob, s_wrong_prob],
    feed_dict={input_data: X_test, targets: Y_test, mask: mask_test})

print('WSJ Error (%) | Prediction Prob (mean, std) | PProb Right (mean, std) | PProb Wrong (mean, std):')
print(err, '|', np.mean(s_p), np.std(s_p), '|', np.mean(s_rp), np.std(s_rp), '|', np.mean(s_wp), np.std(s_wp))

print('\nSuccess Detection')
print('Success base rate (%):', round(100-err,2))
print('KL[p||u]: Right/Wrong classification distinction')
safe, risky = kl_r, kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = s_rp, s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


print('\nError Detection')
print('Error base rate (%):', round(err,2))
safe, risky = -kl_r, -kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('KL[p||u]: Right/Wrong classification distinction')
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = -s_rp, -s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


WSJ Error (%) | Prediction Prob (mean, std) | PProb Right (mean, std) | PProb Wrong (mean, std):
3.68122 | 0.976156 0.0895798 | 0.986341 0.0628749 | 0.709671 0.202042

Success Detection
Success base rate (%): 96.32
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 99.8
AUROC (%): 95.92
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 99.8
AUROC (%): 95.93

Error Detection
Error base rate (%): 3.68
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 50.4
AUROC (%): 95.92
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 50.9
AUROC (%): 95.93
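
To make the AUPR/AUROC bookkeeping above concrete, here is a toy version (made-up scores, not from the run) of the label/score layout fed to sklearn: label 1 marks the positive class and the score is the confidence, so the metrics measure how well the confidence ranks positives above negatives.

import numpy as np
import sklearn.metrics as sk
safe = np.array([[0.9], [0.8], [0.95]])  # confidences on correctly tagged tokens
risky = np.array([[0.4], [0.7]])         # confidences on mistagged tokens
labels = np.zeros(safe.shape[0] + risky.shape[0], dtype=np.int32)
labels[:safe.shape[0]] += 1              # positives come first, matching the vstack order below
scores = np.squeeze(np.vstack((safe, risky)))
print(sk.roc_auc_score(labels, scores), sk.average_precision_score(labels, scores))  # 1.0 1.0 here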

In [12]:
def show_ood_detection_results(error_rate_for_in, in_examples, out_examples, out_mask):
    kl_oos, s_p_oos = sess.run([kl_all, s_prob], feed_dict={input_data: out_examples, mask: out_mask})
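    # kl_a, s_p, kl_r and s_rp used below are the in-distribution WSJ test scores computed in the
    # cells above, reused here as the "normal" reference distribution.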

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(s_p_oos), np.std(s_p_oos))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection')
    safe, risky = kl_a, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection')
    safe, risky = s_p, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - error_rate_for_in/100)*in_examples.shape[0]/
          (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection (relative to correct examples)')
    safe, risky = kl_r, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection (relative to correct examples)')
    safe, risky = s_rp, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection')
    safe, risky = -kl_a, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection')
    safe, risky = -s_p, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
          (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection (relative to correct examples)')
    safe, risky = -kl_r, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection (relative to correct examples)')
    safe, risky = -s_rp, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

In [10]:
reader.tag_to_id   # determine the START, END, and PAD tag ids from this; they are 0, 15, and 16 in this run

In [11]:
def mask_for_data(_dataset, to_ignore=[0,15,16]):
    _mask = np.ones(_dataset.shape, dtype=np.bool)
    for tag_to_ignore in to_ignore:
        _mask = np.logical_and(_mask, _dataset != tag_to_ignore)
    return _mask
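
# toy check (hypothetical tag matrix): mask_for_data(np.array([[0, 7, 15, 16]])) returns
# [[False, True, False, False]], i.e. only the real token at index 1 is kept for scoring.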

vocab = reader.word_to_id.keys()

# map <s> to </s>, since <s> has no embedding in the WSJ vocabulary and </s> is a closer substitute than UNK
xt, yt = data_to_mat('./data/Tweets/tweets-train.txt', vocab, reader.word_to_id,
                     start_tag=0, end_tag=15, pad_tag=16)
xdev, ydev = data_to_mat('./data/Tweets/tweets-dev.txt', vocab, reader.word_to_id,
                         start_tag=0, end_tag=15, pad_tag=16)
xdtest, ydtest = data_to_mat('./data/Tweets/tweets-devtest.txt', vocab, reader.word_to_id,
                             start_tag=0, end_tag=15, pad_tag=16)

tweets = {
    'x_train': xt, 'y_train': yt, 'train_mask': mask_for_data(yt),
    'x_dev': xdev, 'y_dev': ydev, 'dev_mask': mask_for_data(ydev),
    'x_devtest': xdtest, 'y_devtest': ydtest, 'devtest_mask': mask_for_data(ydtest),
}

In [13]:
print('Twitter OOD Detection\n')
show_ood_detection_results(err, X_test, tweets['x_devtest'], tweets['devtest_mask'])


Twitter OOD Detection

OOD Example Prediction Probability (mean, std):
0.80984 0.248264

Normality Detection
Normality base rate (%): 92.58
KL[p||u]: Normality Detection
AUPR (%): 97.62
AUROC (%): 78.17
Prediction Prob: Normality Detection
AUPR (%): 97.6
AUROC (%): 77.79
Normality base rate (%): 92.32
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 97.7
AUROC (%): 79.8
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 97.68
AUROC (%): 79.54


Abnormality Detection
Abnormality base rate (%): 7.42
KL[p||u]: Abnormality Detection
AUPR (%): 35.43
AUROC (%): 78.17
Prediction Prob: Abnormality Detection
AUPR (%): 31.14
AUROC (%): 77.79
Abnormality base rate (%): 7.68
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 44.27
AUROC (%): 79.8
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 40.8
AUROC (%): 79.54

English Web Treebank


In [15]:
xtest, ytest = data_to_mat('./data/WebTreeBank/weblog_penntrees.test.conll', vocab,
                           reader.word_to_id, is_not_twitter=True, start_tag=0, end_tag=15, pad_tag=16)

In [17]:
print('Weblog OOD Detection\n')
show_ood_detection_results(err, X_test, xtest, mask_for_data(ytest))


Weblog OOD Detection

OOD Example Prediction Probability (mean, std):
0.933841 0.15627

Normality Detection
Normality base rate (%): 86.01
KL[p||u]: Normality Detection
AUPR (%): 87.88
AUROC (%): 59.84
Prediction Prob: Normality Detection
AUPR (%): 87.84
AUROC (%): 59.7
Normality base rate (%): 85.55
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 87.94
AUROC (%): 61.6
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 87.91
AUROC (%): 61.48


Abnormality Detection
Abnormality base rate (%): 13.99
KL[p||u]: Abnormality Detection
AUPR (%): 24.98
AUROC (%): 59.84
Prediction Prob: Abnormality Detection
AUPR (%): 23.91
AUROC (%): 59.7
Abnormality base rate (%): 14.45
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 30.57
AUROC (%): 61.6
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 29.52
AUROC (%): 61.48