In [1]:
import tensorflow as tf
import numpy as np
import re
import collections
import sklearn.metrics as sk

In [2]:
def load_data(filename='./data/r52-train.txt'):
    '''
    :param filename: the system location of the data to load;
                     each line is formatted as '<topic> <word> <word> ...'
    :return: the text (x) and its labels (y);
             each text entry is a space-separated string with stop words removed
    '''

    topic_to_num = {"acq": 0,"alum": 1,"bop": 2,"carcass": 3,"cocoa": 4,"coffee": 5,"copper": 6,
                    "cotton": 7,"cpi": 8,"cpu": 9,"crude": 10,"dlr": 11,"earn": 12,"fuel": 13,"gas": 14,
                    "gnp": 15,"gold": 16,"grain": 17,"heat": 18,"housing": 19,"income": 20,"instal-debt": 21,
                    "interest": 22,"ipi": 23,"iron-steel": 24,"jet": 25,"jobs": 26,"lead": 27,"lei": 28,
                    "livestock": 29,"lumber": 30,"meal-feed": 31,"money-fx": 32,"money-supply": 33,
                    "nat-gas": 34,"nickel": 35,"orange": 36,"pet-chem": 37,"platinum": 38,"potato": 39,
                    "reserves": 40,"retail": 41,"rubber": 42,"ship": 43,"strategic-metal": 44,"sugar": 45,
                    "tea": 46,"tin": 47,"trade": 48,"veg-oil": 49,"wpi": 50,"zinc": 51}
    
    # stop words taken from nltk
    stop_words = ['i','me','my','myself','we','our','ours','ourselves','you','your','yours',
                  'yourself','yourselves','he','him','his','himself','she','her','hers','herself',
                  'it','its','itself','they','them','their','theirs','themselves','what','which',
                  'who','whom','this','that','these','those','am','is','are','was','were','be',
                  'been','being','have','has','had','having','do','does','did','doing','a','an',
                  'the','and','but','if','or','because','as','until','while','of','at','by','for',
                  'with','about','against','between','into','through','during','before','after',
                  'above','below','to','from','up','down','in','out','on','off','over','under',
                  'again','further','then','once','here','there','when','where','why','how','all',
                  'any','both','each','few','more','most','other','some','such','no','nor','not',
                  'only','own','same','so','than','too','very','s','t','can','will','just','don',
                  'should','now','d','ll','m','o','re','ve','y','ain','aren','couldn','didn',
                  'doesn','hadn','hasn','haven','isn','ma','mightn','mustn','needn','shan',
                  'shouldn','wasn','weren','won','wouldn']

    x, y = [], []
    with open(filename, "r") as f:
        for line in f:
            line = re.sub(r'\s+', ' ', line).strip()
            current_story = ' '.join(word for word in line.split(' ')[1:] if word not in stop_words)
            x.append(current_story)
            y.append(topic_to_num[line.split(' ')[0]])
    return x, np.array(y, dtype=int)

def get_vocab(dataset):
    '''
    :param dataset: the text from load_data

    :return: an _ordered_ dictionary from words to counts
    '''
    # count every occurrence of every word in the corpus
    vocab = collections.Counter()
    for example in dataset:
        vocab.update(example.split())

    # sort from greatest to least by count
    return collections.OrderedDict(sorted(vocab.items(), key=lambda x: x[1], reverse=True))

def text_to_rank(dataset, _vocab, desired_vocab_size=1000):
    '''
    :param dataset: the text from load_data
    :param _vocab: an _ordered_ dictionary of vocab words and counts from get_vocab
    :param desired_vocab_size: the desired vocabulary size
    words no longer in vocab become UUUNNNKKK
    :return: the text corpus with words mapped to their vocab rank,
    with all sufficiently infrequent words mapped to UUUNNNKKK; UUUNNNKKK has rank desired_vocab_size-1
    (the infrequent-word cutoff is determined by desired_vocab_size)
    '''
    _dataset = dataset[:]     # aliasing safeguard
    vocab_ordered = list(_vocab)
    count_cutoff = _vocab[vocab_ordered[desired_vocab_size-2]]  # count of the last word inside the vocab budget

    word_to_rank = {}
    for i in range(len(vocab_ordered)):
        word_to_rank[vocab_ordered[i]] = i  # rank 0 is the most frequent word

    for i in range(len(_dataset)):
        example = _dataset[i]
        example_as_list = example.split()
        for j in range(len(example_as_list)):
            try:
                if _vocab[example_as_list[j]] >= count_cutoff and word_to_rank[example_as_list[j]] < desired_vocab_size:
                    # the rank check handles ties: words ranked past the vocab budget can
                    # share the cutoff count and must still map to UUUNNNKKK
                    example_as_list[j] = word_to_rank[example_as_list[j]]
                else:
                    example_as_list[j] = desired_vocab_size-1  # UUUNNNKKK
            except KeyError:  # the word was never seen when the vocab was built
                example_as_list[j] = desired_vocab_size-1  # UUUNNNKKK
        _dataset[i] = example_as_list

    return _dataset

def text_to_matrix(dataset, _vocab, desired_vocab_size=1000):
    '''
    :return: a multi-hot matrix; entry (i, j) is 1 iff example i contains a word
    of vocab rank j (column desired_vocab_size-1 marks UUUNNNKKK)
    '''
    sequences = text_to_rank(dataset, _vocab, desired_vocab_size)

    mat = np.zeros((len(sequences), desired_vocab_size), dtype=int)

    for i, seq in enumerate(sequences):
        for token in seq:
            mat[i][token] = 1

    return mat
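
# As a quick sanity check (a hypothetical toy corpus, not from the dataset),
# get_vocab plus text_to_matrix turn raw text into multi-hot rows:
toy = ['oil prices rose', 'oil exports fell', 'coffee prices rose']
toy_vocab = get_vocab(toy)                     # oil/prices/rose have count 2, the rest count 1
toy_mat = text_to_matrix(toy, toy_vocab, desired_vocab_size=5)
print(toy_mat.shape)                           # (3, 5); column 4 is UUUNNNKKK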


In [3]:
def partition_data_in_two(dataset, dataset_labels, in_sample_labels, oos_labels):
    '''
    :param dataset: the matrix from text_to_matrix
    :param dataset_labels: dataset labels
    :param in_sample_labels: a list of classes which the network will/did train on
    :param oos_labels: the complement of in_sample_labels; classes the network has never seen
    :return: the dataset partitioned into in_sample_examples, in_sample_labels,
    oos_examples, and oos_labels in that order
    '''
    _dataset = dataset[:]     # aliasing safeguard
    _dataset_labels = dataset_labels

    # boolean mask marking the examples whose label is in-sample
    in_sample_idxs = np.zeros(np.shape(_dataset_labels), dtype=bool)
    for label in in_sample_labels:
        in_sample_idxs = np.logical_or(in_sample_idxs, _dataset_labels == label)

    return _dataset[in_sample_idxs], _dataset_labels[in_sample_idxs],\
        _dataset[np.logical_not(in_sample_idxs)], _dataset_labels[np.logical_not(in_sample_idxs)]

In [4]:
# our network trains on only a subset of classes, say 6, but class number 7 might
# still be an in-sample label: we need to squish the labels to be in {0,...,5}
def relabel_in_sample_labels(labels):
    labels_as_list = labels.tolist()

    # the sorted unique labels define the new 0-based indexing
    labels_ordered = sorted(set(labels_as_list))

    relabeled = np.zeros(labels.shape, dtype=int)
    for i in range(len(labels_as_list)):
        relabeled[i] = labels_ordered.index(labels_as_list[i])

    return relabeled
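
# For instance (hypothetical label values), surviving labels {3, 7, 41} are
# remapped to {0, 1, 2}:
print(relabel_in_sample_labels(np.array([7, 3, 41, 7])))  # -> [1 0 2 1]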

In [5]:
batch_size = 32
vocab_size = 1000
num_epochs = 5
n_hidden = 512
nclasses_to_exclude = 12  # number of held-out classes, out of 52

In [16]:
random_classes = np.arange(52)
np.random.shuffle(random_classes)
to_include = list(random_classes[:52-nclasses_to_exclude])
to_exclude = list(random_classes[52-nclasses_to_exclude:])

In [17]:
print('Loading Data')
X_train, Y_train = load_data()
X_test, Y_test = load_data('./data/r52-test.txt')

vocab = get_vocab(X_train)
X_train = text_to_matrix(X_train, vocab, vocab_size)
X_test = text_to_matrix(X_test, vocab, vocab_size)

# shuffle
indices = np.arange(X_train.shape[0])
np.random.shuffle(indices)
X_train = X_train[indices]
Y_train = Y_train[indices]

indices = np.arange(X_test.shape[0])
np.random.shuffle(indices)
X_test = X_test[indices]
Y_test = Y_test[indices]

in_sample_examples, in_sample_labels, oos_examples, oos_labels =\
partition_data_in_two(X_train, Y_train, to_include, to_exclude)
test_in_sample_examples, test_in_sample_labels, test_oos_examples, test_oos_labels =\
partition_data_in_two(X_test, Y_test, to_include, to_exclude)

# this safely assumes every in-sample class has at least one example in both the training and test splits
in_sample_labels = relabel_in_sample_labels(in_sample_labels)
test_in_sample_labels = relabel_in_sample_labels(test_in_sample_labels)

num_examples = in_sample_labels.shape[0]
num_batches = num_examples//batch_size

print('Data loaded')


Loading Data
Data loaded

In [18]:
graph = tf.Graph()

with graph.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[None, vocab_size])
    y = tf.placeholder(dtype=tf.int64, shape=[None])
    is_training = tf.placeholder(tf.bool)
    
    # hidden layer weights: unit-norm columns, scaled down for the nonlinearity
    W_h = tf.Variable(tf.nn.l2_normalize(tf.random_normal([vocab_size, n_hidden]), 0)/tf.sqrt(1 + 0.45))
    b_h = tf.Variable(tf.zeros([n_hidden]))
    
    def gelu_fast(_x):
        # tanh approximation of the GELU nonlinearity, x * Phi(x) (checked numerically after this cell)
        return 0.5 * _x * (1 + tf.tanh(tf.sqrt(2 / np.pi) * (_x + 0.044715 * tf.pow(_x, 3))))
    
    h = tf.cond(is_training,
                lambda: tf.nn.dropout(gelu_fast(tf.matmul(x, W_h) + b_h), 0.5),
                lambda: gelu_fast(tf.matmul(x, W_h) + b_h))
    
    W_out = tf.Variable(tf.nn.l2_normalize(tf.random_normal([n_hidden, 52-nclasses_to_exclude]), 0)/tf.sqrt(1 + 0.45))
    b_out = tf.Variable(tf.zeros([52-nclasses_to_exclude]))
    
    logits = tf.matmul(h, W_out) + b_out
    
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))

    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(1e-3, global_step, 4*num_batches, 0.1, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss, global_step=global_step)

    acc = 100*tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(logits, 1), y)))
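
# gelu_fast above is the tanh approximation of the GELU, x * Phi(x), where Phi is
# the standard normal CDF. A minimal NumPy check of the approximation error
# (assumes scipy is available; an illustration, not part of the original notebook):
import scipy.stats
pts = np.linspace(-3, 3, 61)
exact = pts * scipy.stats.norm.cdf(pts)
approx = 0.5 * pts * (1 + np.tanh(np.sqrt(2 / np.pi) * (pts + 0.044715 * pts**3)))
print(np.max(np.abs(approx - exact)))  # small; well under 1e-2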

In [19]:
# initialize
sess = tf.InteractiveSession(graph=graph)
tf.global_variables_initializer().run()
# create a saver to checkpoint the trained model
saver = tf.train.Saver(max_to_keep=1)

print('Initialized')


Initialized

In [20]:
for epoch in range(num_epochs):
    # shuffle data every epoch
    indices = np.arange(num_examples)
    np.random.shuffle(indices)
    in_sample_examples = in_sample_examples[indices]
    in_sample_labels = in_sample_labels[indices]

    for i in range(num_batches):
        offset = i * batch_size

        x_batch = in_sample_examples[offset:offset + batch_size]
        y_batch = in_sample_labels[offset:offset + batch_size]

        _, l, batch_acc = sess.run([optimizer, loss, acc], feed_dict={x: x_batch, y: y_batch, is_training: True})

    print('Epoch %d | Minibatch loss %.3f | Minibatch accuracy %.3f' %
          (epoch+1, l, batch_acc))


Epoch 1 | Minibatch loss 0.526 | Minibatch accuracy 87.500
Epoch 2 | Minibatch loss 0.260 | Minibatch accuracy 93.750
Epoch 3 | Minibatch loss 0.106 | Minibatch accuracy 100.000
Epoch 4 | Minibatch loss 0.019 | Minibatch accuracy 100.000
Epoch 5 | Minibatch loss 0.052 | Minibatch accuracy 100.000

In [21]:
s = tf.nn.softmax(logits)
s_prob = tf.reduce_max(s, reduction_indices=[1], keep_dims=True)
kl_all = tf.log(52. - nclasses_to_exclude)\
        + tf.reduce_sum(s * tf.log(tf.abs(s) + 1e-11), reduction_indices=[1], keep_dims=True)
m_all, v_all = tf.nn.moments(kl_all, axes=[0])

logits_right = tf.boolean_mask(logits, tf.equal(tf.argmax(logits, 1), y))
s_right = tf.nn.softmax(logits_right)
s_right_prob = tf.reduce_max(s_right, reduction_indices=[1], keep_dims=True)
kl_right = tf.log(52. - nclasses_to_exclude)\
         + tf.reduce_sum(s_right * tf.log(tf.abs(s_right) + 1e-11), reduction_indices=[1], keep_dims=True)
m_right, v_right = tf.nn.moments(kl_right, axes=[0])

logits_wrong = tf.boolean_mask(logits, tf.not_equal(tf.argmax(logits, 1), y))
s_wrong = tf.nn.softmax(logits_wrong)
s_wrong_prob = tf.reduce_max(s_wrong, reduction_indices=[1], keep_dims=True)
kl_wrong = tf.log(52. - nclasses_to_exclude)\
           + tf.reduce_sum(s_wrong * tf.log(tf.abs(s_wrong) + 1e-11), reduction_indices=[1], keep_dims=True)
m_wrong, v_wrong = tf.nn.moments(kl_wrong, axes=[0])
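
# kl_all/kl_right/kl_wrong above compute KL[p||u]: the KL divergence from the
# softmax distribution p to the uniform distribution u over the K in-sample
# classes, which simplifies to log(K) + sum_k p_k log p_k. A NumPy sketch of its
# range (an illustration, not part of the original notebook):
K = 52 - nclasses_to_exclude
uniform_p = np.ones(K) / K
confident_p = np.zeros(K)
confident_p[0] = 1.0
for p in (uniform_p, confident_p):
    print(np.log(K) + np.sum(p * np.log(p + 1e-11)))
# prints ~0 for a maximally uncertain prediction and ~log(K) for a fully confident one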

In [22]:
err, kl_a, kl_r, kl_w, s_p, s_rp, s_wp = sess.run(
    [100 - acc, kl_all, kl_right, kl_wrong, s_prob, s_right_prob, s_wrong_prob],
    feed_dict={x: test_in_sample_examples, y: test_in_sample_labels, is_training: False})

print('Reuters52 (w/class subset) Error (%) | Prediction Prob (mean, std) | PProb Right (mean, std) | PProb Wrong (mean, std):')
print(err, '|', np.mean(s_p), np.std(s_p), '|', np.mean(s_rp), np.std(s_rp), '|', np.mean(s_wp), np.std(s_wp))

print('\nSuccess Detection')
print('Success base rate (%):', round(100-err,2))
print('KL[p||u]: Right/Wrong classification distinction')
safe, risky = kl_r, kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = s_rp, s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


print('\nError Detection')
print('Error base rate (%):', round(err,2))
safe, risky = -kl_r, -kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('KL[p||u]: Right/Wrong classification distinction')
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = -s_rp, -s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


Reuters52 (w/class subset) Error (%) | Prediction Prob (mean, std) | PProb Right (mean, std) | PProb Wrong (mean, std):
7.93185 | 0.908305 0.205085 | 0.940961 0.157806 | 0.529264 0.291815

Success Detection
Success base rate (%): 92.07
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 99.18
AUROC (%): 91.47
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 99.28
AUROC (%): 92.53

Error Detection
Error base rate (%): 7.93
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 45.17
AUROC (%): 91.47
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 51.85
AUROC (%): 92.53
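
A note on reading these numbers: AUROC is independent of the base rate, while AUPR is anchored to it, so an error-detection AUPR near 50% against an ~8% error base rate is far above chance. A hypothetical sanity check with random scores (an illustration, not part of the original notebook):

rng = np.random.RandomState(0)
fake_labels = (rng.rand(10000) < 0.08).astype(np.int32)  # ~8% positives, like the error base rate
fake_scores = rng.rand(10000)                            # an uninformative detector
print(sk.average_precision_score(fake_labels, fake_scores))  # ~0.08, the base rate
print(sk.roc_auc_score(fake_labels, fake_scores))            # ~0.5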

In [23]:
def show_ood_detection_results(error_rate_for_in, in_examples, out_examples):
    '''score in-distribution vs out-of-distribution examples and report AUPR/AUROC both ways'''
    kl_oos, s_p_oos = sess.run([kl_all, s_prob], feed_dict={x: out_examples, is_training: False})

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(s_p_oos), np.std(s_p_oos))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection')
    safe, risky = kl_a, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection')
    safe, risky = s_p, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - error_rate_for_in/100)*in_examples.shape[0]/
          (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection (relative to correct examples)')
    safe, risky = kl_r, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection (relative to correct examples)')
    safe, risky = s_rp, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection')
    safe, risky = -kl_a, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection')
    safe, risky = -s_p, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
          (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection (relative to correct examples)')
    safe, risky = -kl_r, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection (relative to correct examples)')
    safe, risky = -s_rp, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

In [24]:
print('Held-out subjects\n')
show_ood_detection_results(err, test_in_sample_examples, test_oos_examples)


Held-out subjects

OOD Example Prediction Probability (mean, std):
0.687292 0.300118

Normality Detection
Normality base rate (%): 66.28
KL[p||u]: Normality Detection
AUPR (%): 90.86
AUROC (%): 82.14
Prediction Prob: Normality Detection
AUPR (%): 90.77
AUROC (%): 81.7
Normality base rate (%): 64.41
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 92.07
AUROC (%): 85.91
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 92.09
AUROC (%): 85.79


Abnormality Detection
Abnormality base rate (%): 33.72
KL[p||u]: Abnormality Detection
AUPR (%): 62.47
AUROC (%): 82.14
Prediction Prob: Abnormality Detection
AUPR (%): 61.98
AUROC (%): 81.7
Abnormality base rate (%): 35.59
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 73.02
AUROC (%): 85.91
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 73.61
AUROC (%): 85.79