In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
import tensorflow as tf
import numpy as np
import sys
import os
import pickle
from load_cifar10 import load_data10
import sklearn.metrics as sk

# training parameters
learning_rate = 0.001
training_epochs = 30
batch_size = 128

# architecture parameters
n_labels = 10
image_pixels = 28 * 28
bottleneck = 128


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz

In [2]:
def add_noise(batch, complexity=0.5):
    return batch + np.random.normal(size=batch.shape, scale=1e-9 + complexity)

from skimage.filters import gaussian
def blur(img, complexity=0.5):
    image = img.reshape((-1, 28, 28))
    return gaussian(image, sigma=5*complexity).reshape((-1, 28*28))

In [3]:
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[None, image_pixels])
    y = tf.placeholder(dtype=tf.int64, shape=[None])
    risk_labels = tf.placeholder(dtype=tf.float32, shape=[None])
    is_unfrozen = tf.placeholder(tf.bool)


    def gelu_fast(x):
        return 0.5 * x * (1 + tf.tanh(tf.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
    rho = gelu_fast

    # if mode == 'input_restricted':
    W = {}
    b = {}

    with tf.variable_scope("in_sample"):
        W['1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([image_pixels, 256]), 0))
        W['2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['3'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['logits'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, n_labels]), 0))
        W['bottleneck'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, bottleneck]), 0))
        W['decode1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([bottleneck, 256]), 0))
        W['decode2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['reconstruction'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, image_pixels]), 0))

        b['1'] = tf.Variable(tf.zeros([256]))
        b['2'] = tf.Variable(tf.zeros([256]))
        b['3'] = tf.Variable(tf.zeros([256]))
        b['logits'] = tf.Variable(tf.zeros([n_labels]))
        b['bottleneck'] = tf.Variable(tf.zeros([bottleneck]))
        b['decode1'] = tf.Variable(tf.zeros([256]))
        b['decode2'] = tf.Variable(tf.zeros([256]))
        b['reconstruction'] = tf.Variable(tf.zeros([image_pixels]))

    with tf.variable_scope("out_of_sample"):
        W['residual_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([image_pixels, 512]), 0))
        W['hidden_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 512]), 0))
        W['logits_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([n_labels, 512]), 0))
        W['risk2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([512, 128]), 0))
        W['risk'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([128, 1]), 0))

        b['risk1'] = tf.Variable(tf.zeros([512]))
        b['risk2'] = tf.Variable(tf.zeros([128]))
        b['risk'] = tf.Variable(tf.zeros([1]))

    def risk_net(x):
        h1 = rho(tf.matmul(x, W['1']) + b['1'])
        h2 = rho(tf.matmul(h1, W['2']) + b['2'])
        h3 = rho(tf.matmul(h2, W['3']) + b['3'])
        logits_out = tf.matmul(h3, W['logits']) + b['logits']

        hidden_to_bottleneck = rho(tf.matmul(h2, W['bottleneck']) + b['bottleneck'])
        d1 = rho(tf.matmul(hidden_to_bottleneck, W['decode1']) + b['decode1'])
        d2 = rho(tf.matmul(d1, W['decode2']) + b['decode2'])
        recreation = tf.matmul(d2, W['reconstruction']) + b['reconstruction']

        risk1 = rho(tf.matmul(logits_out, W['logits_to_risk1']) +
                    tf.matmul(tf.square(x - recreation), W['residual_to_risk1']) +
                    tf.matmul(h2, W['hidden_to_risk1']) + b['risk1'])
        risk2 = rho(tf.matmul(risk1, W['risk2']) + b['risk2'])
        risk_out = tf.matmul(risk2, W['risk'])

        return logits_out, recreation, tf.squeeze(risk_out)

    logits, reconstruction, risk = risk_net(x)
    ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y))
    rec_error = tf.reduce_mean(tf.square(x - reconstruction))
    loss = 0.9 * ce + 0.1 * rec_error
    lr = tf.constant(0.001)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
    
    compute_error = 100*tf.reduce_mean(tf.to_float(tf.not_equal(tf.argmax(logits, 1), y)))
    compute_risk_error = 100*tf.reduce_mean(tf.to_float(tf.not_equal(tf.to_int64(tf.round(tf.sigmoid(risk))),
                                                                     tf.to_int64(tf.round(risk_labels)))))

In [7]:
sess.close()

In [4]:
sess = tf.InteractiveSession(graph=graph)
print('Beginning training: Phase 1')

# Adam requires special care
in_sample_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "in_sample")
out_of_sample_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "out_of_sample")
sess.run(tf.initialize_variables(set(tf.all_variables()) - set(out_of_sample_vars)))

num_batches = int(mnist.train.num_examples / batch_size)
save_every = int(num_batches/3.1)      # save training information 3 times per epoch
ce_ema = 2.3            # - log(0.1)
err_ema = 0.9
risk_loss_ema = 0.3     # - log(0.5)
learning_rate = 0.001
for epoch in range(training_epochs):
    if epoch >= 20:
        learning_rate = 0.0001
    for i in range(num_batches):
        bx, by = mnist.train.next_batch(batch_size)
        _, err, l = sess.run([optimizer, compute_error, ce], feed_dict={x: bx, y: by, is_unfrozen: True,
                                                                        lr: learning_rate})
        ce_ema = ce_ema * 0.95 + 0.05 * l
        err_ema = err_ema * 0.95 + 0.05 * err


Beginning training: Phase 1

In [5]:
print('Entering out-of-distribution training phase')

# Adam requires special care
risk_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(risk, risk_labels))
phase2_vars = list(set(tf.all_variables()) - set(in_sample_vars))
curr_lr = learning_rate
risk_optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(risk_loss, var_list=phase2_vars)
sess.run(tf.initialize_variables(set(tf.all_variables()) - set(in_sample_vars)))

err_ema = 50
for epoch in range(15):
    for i in range(num_batches):
        offset = i * batch_size
        bx, by = mnist.train.next_batch(1.5*batch_size)

        bx1 = bx[0:batch_size//6]
        bx2 = mnist.validation.next_batch(batch_size//3)[0][batch_size//6:batch_size//3]
        distortion = np.random.uniform(low=0.9, high=1.2)
        bx3 = add_noise(bx[batch_size//3:4*batch_size//6], complexity=distortion)
        distortion = np.random.uniform(low=0.9, high=1)
        bx4 = blur(bx[4*batch_size//6:5*batch_size//6], complexity=distortion)
        bx5 = np.zeros(shape=(batch_size - 5*batch_size//6, 28*28))
        for k in range(5*batch_size//6, batch_size):
            if by[k] == 0:
                bx5[k - 5*batch_size//6] = add_noise(bx[k], complexity=1).reshape((28*28,))
            else:
                bx5[k - 5*batch_size//6] = np.rot90(bx[k].reshape((28,28)),
                                                    k=np.random.choice([1,3])).reshape((28*28,))
        risks = np.zeros(batch_size)
        risks[:batch_size//3] = 1
        bx = np.vstack((bx1, bx2, bx3, bx4, bx5))

        _, rl, err = sess.run([risk_optimizer, risk_loss, compute_risk_error],
                              feed_dict={x: bx, risk_labels: risks, is_unfrozen: False})
        risk_loss_ema = risk_loss_ema * 0.95 + 0.05 * rl
        err_ema = err_ema * 0.95 + 0.05 * err

    print('Epoch:', epoch, '|', 'ema of risk for epoch:', risk_loss_ema, 'error (%):', err_ema)


Entering out-of-distribution training phase
/scratch/Software/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:162: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  return self._images[start:end], self._labels[start:end]
Epoch: 0 | ema of risk for epoch: 0.0289991779286 error (%): 1.02815375856
Epoch: 1 | ema of risk for epoch: 0.0235179694132 error (%): 0.796505824561
Epoch: 2 | ema of risk for epoch: 0.0173152953255 error (%): 0.727457317635
Epoch: 3 | ema of risk for epoch: 0.0137615267565 error (%): 0.513664685183
Epoch: 4 | ema of risk for epoch: 0.0132930072201 error (%): 0.45190325967
Epoch: 5 | ema of risk for epoch: 0.0176867936823 error (%): 0.733053278535
Epoch: 6 | ema of risk for epoch: 0.0082577667913 error (%): 0.248823425983
Epoch: 7 | ema of risk for epoch: 0.0105122508973 error (%): 0.283814799134
Epoch: 8 | ema of risk for epoch: 0.011084522913 error (%): 0.449430105553
Epoch: 9 | ema of risk for epoch: 0.0104165609187 error (%): 0.342036450263
Epoch: 10 | ema of risk for epoch: 0.00670572366864 error (%): 0.227839521794
Epoch: 11 | ema of risk for epoch: 0.00618380575562 error (%): 0.24187715303
Epoch: 12 | ema of risk for epoch: 0.00836288376746 error (%): 0.214100012671
Epoch: 13 | ema of risk for epoch: 0.0107128988622 error (%): 0.329135162083
Epoch: 14 | ema of risk for epoch: 0.00716427481169 error (%): 0.237637717177

In [6]:
# load notMNIST, CIFAR-10, and Omniglot
pickle_file = './data/notMNIST.pickle'
with open(pickle_file, 'rb') as f:
    save = pickle.load(f, encoding='latin1')
    notmnist_dataset = save['test_dataset'].reshape((-1, 28 * 28))
    del save

_, _, X_test, _ = load_data10()
cifar_batch = sess.run(tf.image.resize_images(tf.image.rgb_to_grayscale(X_test), 28, 28))

import scipy.io as sio
import scipy.misc as scimisc
# other alphabets have characters which overlap
safe_list = [0,2,5,6,8,12,13,14,15,16,17,18,19,21,26]
m = sio.loadmat("./data/data_background.mat")

squished_set = []
for safe_number in safe_list:
    for alphabet in m['images'][safe_number]:
        for letters in alphabet:
            for letter in letters:
                for example in letter:
                    squished_set.append(scimisc.imresize(1 - example[0], (28,28)).reshape(1, 28*28))

omni_images = np.concatenate(squished_set, axis=0)

In [7]:
err, r_err, r, conf  = sess.run([compute_error, compute_risk_error, tf.sigmoid(risk), tf.nn.softmax(logits)],
                          feed_dict={x: mnist.test.images, y: mnist.test.labels, risk_labels: np.ones(10000)})

r_right = r[np.argmax(conf, axis=1).astype(np.uint8) == mnist.test.labels]

print('MNIST Digit Error (%) | MNIST Riskiness Error (%) | Digit Confidence (mean, std):')
print(err, '|', r_err, '|', np.mean(np.max(conf, axis=1)), np.std(np.max(conf, axis=1)))


MNIST Digit Error (%) | MNIST Riskiness Error (%) | Digit Confidence (mean, std):
1.46 | 1.07 | 0.995995 0.0350266

In [9]:
s = tf.nn.softmax(logits)
s_prob = tf.reduce_max(s, reduction_indices=[1], keep_dims=True)
kl_all = tf.log(10.) + tf.reduce_sum(s * tf.log(tf.abs(s) + 1e-11), reduction_indices=[1], keep_dims=True)
m_all, v_all = tf.nn.moments(kl_all, axes=[0])

logits_right = tf.boolean_mask(logits, tf.equal(tf.argmax(logits, 1), y))
s_right = tf.nn.softmax(logits_right)
s_right_prob = tf.reduce_max(s_right, reduction_indices=[1], keep_dims=True)
kl_right = tf.log(10.) + tf.reduce_sum(s_right * tf.log(tf.abs(s_right) + 1e-11),
                                       reduction_indices=[1], keep_dims=True)
m_right, v_right = tf.nn.moments(kl_right, axes=[0])

logits_wrong = tf.boolean_mask(logits, tf.not_equal(tf.argmax(logits, 1), y))
s_wrong = tf.nn.softmax(logits_wrong)
s_wrong_prob = tf.reduce_max(s_wrong, reduction_indices=[1], keep_dims=True)
kl_wrong = tf.log(10.) + tf.reduce_sum(s_wrong * tf.log(tf.abs(s_wrong) + 1e-11),
                                       reduction_indices=[1], keep_dims=True)
m_wrong, v_wrong = tf.nn.moments(kl_wrong, axes=[0])

In [12]:
kl_a, kl_r, kl_w, s_p, s_rp, s_wp = sess.run(
    [kl_all, kl_right, kl_wrong, s_prob, s_right_prob, s_wrong_prob],
    feed_dict={x: mnist.test.images, y: mnist.test.labels})

print('\nSuccess Detection')
print('Success base rate (%):', round(100-err,2))
print('KL[p||u]: Right/Wrong classification distinction')
safe, risky = kl_r, kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = s_rp, s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


print('\nError Detection')
print('Error base rate (%):', round(err,2))
safe, risky = -kl_r, -kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('KL[p||u]: Right/Wrong classification distinction')
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = -s_rp, -s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


Success Detection
Success base rate (%): 98.54
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 99.96
AUROC (%): 97.45
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 99.95
AUROC (%): 96.67

Error Detection
Error base rate (%): 1.46
KL[p||u]: Right/Wrong classification distinction
AUPR (%): 39.21
AUROC (%): 97.45
Prediction Prob: Right/Wrong classification distinction
AUPR (%): 39.33
AUROC (%): 96.67

In [17]:
def show_ood_detection_results_softmax(error_rate_for_in, in_examples, out_examples):
    kl_oos, s_p_oos = sess.run([kl_all, s_prob], feed_dict={x: out_examples})

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(s_p_oos), np.std(s_p_oos))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection')
    safe, risky = kl_a, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection')
    safe, risky = s_p, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - err/100)*in_examples.shape[0]/
          (out_examples.shape[0] + (1 - err/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection (relative to correct examples)')
    safe, risky = kl_r, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Normality Detection (relative to correct examples)')
    safe, risky = s_rp, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))


    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection')
    safe, risky = -kl_a, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection')
    safe, risky = -s_p, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
          (out_examples.shape[0] + (1 - err/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection (relative to correct examples)')
    safe, risky = -kl_r, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Prediction Prob: Abnormality Detection (relative to correct examples)')
    safe, risky = -s_rp, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

In [18]:
print('Omniglot (Softmax version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, omni_images.reshape(-1, 28*28))


Omniglot (Softmax version)

OOD Example Prediction Probability (mean, std):
0.878658 0.174975

Normality Detection
Normality base rate (%): 52.08
KL[p||u]: Normality Detection
AUPR (%): 95.47
AUROC (%): 94.75
Prediction Prob: Normality Detection
AUPR (%): 95.01
AUROC (%): 94.12
Normality base rate (%): 51.72
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 95.85
AUROC (%): 95.43
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 95.39
AUROC (%): 94.81


Abnormality Detection
Abnormality base rate (%): 47.92
KL[p||u]: Abnormality Detection
AUPR (%): 93.89
AUROC (%): 94.75
Prediction Prob: Abnormality Detection
AUPR (%): 93.39
AUROC (%): 94.12
Abnormality base rate (%): 48.28
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 95.2
AUROC (%): 95.43
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 94.83
AUROC (%): 94.81

In [19]:
print('notMNIST (Softmax version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, notmnist_dataset)


notMNIST (Softmax version)

OOD Example Prediction Probability (mean, std):
0.917525 0.14954

Normality Detection
Normality base rate (%): 50.0
KL[p||u]: Normality Detection
AUPR (%): 87.12
AUROC (%): 86.69
Prediction Prob: Normality Detection
AUPR (%): 87.5
AUROC (%): 86.1
Normality base rate (%): 49.63
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 87.49
AUROC (%): 87.44
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 87.87
AUROC (%): 86.85


Abnormality Detection
Abnormality base rate (%): 50.0
KL[p||u]: Abnormality Detection
AUPR (%): 88.22
AUROC (%): 86.69
Prediction Prob: Abnormality Detection
AUPR (%): 88.01
AUROC (%): 86.1
Abnormality base rate (%): 50.37
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 89.77
AUROC (%): 87.44
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 89.67
AUROC (%): 86.85

In [20]:
print('CIFAR-10bw (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, cifar_batch.reshape(-1, 28*28))


CIFAR-10bw (Softmax Version)

OOD Example Prediction Probability (mean, std):
0.792851 0.217157

Normality Detection
Normality base rate (%): 50.0
KL[p||u]: Normality Detection
AUPR (%): 97.77
AUROC (%): 97.54
Prediction Prob: Normality Detection
AUPR (%): 97.57
AUROC (%): 97.26
Normality base rate (%): 49.63
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 98.09
AUROC (%): 98.02
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 97.91
AUROC (%): 97.78


Abnormality Detection
Abnormality base rate (%): 50.0
KL[p||u]: Abnormality Detection
AUPR (%): 97.35
AUROC (%): 97.54
Prediction Prob: Abnormality Detection
AUPR (%): 96.96
AUROC (%): 97.26
Abnormality base rate (%): 50.37
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 98.04
AUROC (%): 98.02
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 97.81
AUROC (%): 97.78

In [21]:
print('Sheer White Gaussian Noise (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, np.random.normal(size=(10000, 28*28)))


Sheer White Gaussian Noise (Softmax Version)

OOD Example Prediction Probability (mean, std):
0.925578 0.13632

Normality Detection
Normality base rate (%): 50.0
KL[p||u]: Normality Detection
AUPR (%): 87.93
AUROC (%): 87.52
Prediction Prob: Normality Detection
AUPR (%): 88.06
AUROC (%): 86.78
Normality base rate (%): 49.63
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 88.32
AUROC (%): 88.29
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 88.45
AUROC (%): 87.54


Abnormality Detection
Abnormality base rate (%): 50.0
KL[p||u]: Abnormality Detection
AUPR (%): 88.46
AUROC (%): 87.52
Prediction Prob: Abnormality Detection
AUPR (%): 88.21
AUROC (%): 86.78
Abnormality base rate (%): 50.37
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 90.15
AUROC (%): 88.29
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 89.97
AUROC (%): 87.54

In [22]:
print('Sheer Uniform Noise (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, np.random.uniform(size=(10000, 28*28)))


Sheer Uniform Noise (Softmax Version)

OOD Example Prediction Probability (mean, std):
0.706721 0.229194

Normality Detection
Normality base rate (%): 50.0
KL[p||u]: Normality Detection
AUPR (%): 99.36
AUROC (%): 99.19
Prediction Prob: Normality Detection
AUPR (%): 99.27
AUROC (%): 99.02
Normality base rate (%): 49.63
KL[p||u]: Normality Detection (relative to correct examples)
AUPR (%): 99.57
AUROC (%): 99.48
Prediction Prob: Normality Detection (relative to correct examples)
AUPR (%): 99.51
AUROC (%): 99.37


Abnormality Detection
Abnormality base rate (%): 50.0
KL[p||u]: Abnormality Detection
AUPR (%): 99.03
AUROC (%): 99.19
Prediction Prob: Abnormality Detection
AUPR (%): 98.68
AUROC (%): 99.02
Abnormality base rate (%): 50.37
KL[p||u]: Abnormality Detection (relative to correct examples)
AUPR (%): 99.41
AUROC (%): 99.48
Prediction Prob: Abnormality Detection (relative to correct examples)
AUPR (%): 99.22
AUROC (%): 99.37

In [23]:
def show_ood_detection_results(error_rate_for_in, in_examples, out_examples):
    r_oos, conf = sess.run([tf.sigmoid(risk), tf.nn.softmax(logits)], feed_dict={x: out_examples})
    
    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(np.max(conf, axis=1)), np.std(np.max(conf, axis=1)))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('Normality Detection')
    safe, risky = r.reshape(-1,1), r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - err/100)*in_examples.shape[0]/
          (out_examples.shape[0] + (1 - err/100)*in_examples.shape[0]),2))
    print('Normality Detection (relative to correct examples)')
    safe, risky = r_right.reshape(-1,1), r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
                out_examples.shape[0] + in_examples.shape[0]),2))
    print('Abnormality Detection')
    safe, risky = 1 - r.reshape(-1,1), 1 - r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
          (out_examples.shape[0] + (1 - err/100)*in_examples.shape[0]),2))
    print('Abnormality Detection (relative to correct examples)')
    safe, risky = 1 - r_right.reshape(-1,1), 1 - r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

In [24]:
print('Omniglot\n')
show_ood_detection_results(err, mnist.test.images, omni_images.reshape(-1, 28*28))


Omniglot

OOD Example Prediction Probability (mean, std):
0.878658 0.174975

Normality Detection
Normality base rate (%): 52.08
Normality Detection
AUPR (%): 99.56
AUROC (%): 99.5
Normality base rate (%): 51.72
Normality Detection (relative to correct examples)
AUPR (%): 99.62
AUROC (%): 99.59


Abnormality Detection
Abnormality base rate (%): 47.92
Abnormality Detection
AUPR (%): 99.45
AUROC (%): 99.5
Abnormality base rate (%): 48.28
Abnormality Detection (relative to correct examples)
AUPR (%): 99.57
AUROC (%): 99.59

In [25]:
print('notMNIST\n')
show_ood_detection_results(err, mnist.test.images, notmnist_dataset)


notMNIST

OOD Example Prediction Probability (mean, std):
0.917525 0.14954

Normality Detection
Normality base rate (%): 50.0
Normality Detection
AUPR (%): 99.99
AUROC (%): 99.99
Normality base rate (%): 49.63
Normality Detection (relative to correct examples)
AUPR (%): 99.99
AUROC (%): 99.99


Abnormality Detection
Abnormality base rate (%): 50.0
Abnormality Detection
AUPR (%): 99.99
AUROC (%): 99.99
Abnormality base rate (%): 50.37
Abnormality Detection (relative to correct examples)
AUPR (%): 99.99
AUROC (%): 99.99

In [26]:
print('CIFAR-10bw\n')
show_ood_detection_results(err, mnist.test.images, cifar_batch.reshape(-1, 28*28))


CIFAR-10bw

OOD Example Prediction Probability (mean, std):
0.792851 0.217157

Normality Detection
Normality base rate (%): 50.0
Normality Detection
AUPR (%): 99.93
AUROC (%): 99.93
Normality base rate (%): 49.63
Normality Detection (relative to correct examples)
AUPR (%): 99.95
AUROC (%): 99.95


Abnormality Detection
Abnormality base rate (%): 50.0
Abnormality Detection
AUPR (%): 99.93
AUROC (%): 99.93
Abnormality base rate (%): 50.37
Abnormality Detection (relative to correct examples)
AUPR (%): 99.95
AUROC (%): 99.95

In [27]:
print('Sheer White Gaussian Noise\n')
show_ood_detection_results(err, mnist.test.images, np.random.normal(size=(10000, 28*28)))


Sheer White Gaussian Noise

OOD Example Prediction Probability (mean, std):
0.920849 0.144049

Normality Detection
Normality base rate (%): 50.0
Normality Detection
AUPR (%): 100.0
AUROC (%): 100.0
Normality base rate (%): 49.63
Normality Detection (relative to correct examples)
AUPR (%): 100.0
AUROC (%): 100.0


Abnormality Detection
Abnormality base rate (%): 50.0
Abnormality Detection
AUPR (%): 100.0
AUROC (%): 100.0
Abnormality base rate (%): 50.37
Abnormality Detection (relative to correct examples)
AUPR (%): 100.0
AUROC (%): 100.0

In [28]:
print('Sheer Uniform Noise\n')
show_ood_detection_results(err, mnist.test.images, np.random.uniform(size=(10000, 28*28)))


Sheer Uniform Noise

OOD Example Prediction Probability (mean, std):
0.707209 0.228498

Normality Detection
Normality base rate (%): 50.0
Normality Detection
AUPR (%): 100.0
AUROC (%): 100.0
Normality base rate (%): 49.63
Normality Detection (relative to correct examples)
AUPR (%): 100.0
AUROC (%): 100.0


Abnormality Detection
Abnormality base rate (%): 50.0
Abnormality Detection
AUPR (%): 100.0
AUROC (%): 100.0
Abnormality base rate (%): 50.37
Abnormality Detection (relative to correct examples)
AUPR (%): 100.0
AUROC (%): 100.0