In [1]:
# import MNIST data, Tensorflow, and other helpers
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TensorFlow tutorial helper, using /tmp/data/ as the
# data directory. Later cells read the .train, .validation and .test splits
# of `mnist`.
mnist = input_data.read_data_sets("/tmp/data/")
import tensorflow as tf
import numpy as np
import sys
import os
import pickle
from load_cifar10 import load_data10
import sklearn.metrics as sk
# --- optimisation schedule -------------------------------------------------
training_epochs = 30   # number of passes made by the training loop
batch_size = 128       # examples requested per next_batch() call

# --- network dimensions ----------------------------------------------------
n_labels = 10          # width of the logits layer (one unit per class)
image_side = 28        # images are fed flattened: image_side * image_side values
image_pixels = image_side * image_side
In [3]:
graph = tf.Graph()
with graph.as_default():
    # Inputs: flattened images and their integer class ids.
    x = tf.placeholder(dtype=tf.float32, shape=[None, image_pixels])
    y = tf.placeholder(dtype=tf.int64, shape=[None])
    # NOTE(review): no op in this notebook consumes risk_labels; it is kept only
    # so that code outside this cell that may reference it keeps working.
    risk_labels = tf.placeholder(dtype=tf.float32, shape=[None])

    def gelu_fast(x):
        """tanh-based approximation of the GELU activation."""
        return 0.5 * x * (1 + tf.tanh(tf.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
    rho = gelu_fast

    hidden_units = 256

    def _scaled_init(shape, scale_sq):
        """Weight variable: Gaussian draw, L2-normalised along dim 0, divided by sqrt(scale_sq)."""
        return tf.Variable(tf.nn.l2_normalize(tf.random_normal(shape), 0) / tf.sqrt(scale_sq))

    # Three hidden layers of `hidden_units` units followed by a linear logits layer.
    W = {}
    b = {}
    W['1'] = _scaled_init([image_pixels, hidden_units], 1 + 0.425)
    W['2'] = _scaled_init([hidden_units, hidden_units], 0.425 + 0.425)
    W['3'] = _scaled_init([hidden_units, hidden_units], 0.425 + 0.425)
    W['logits'] = _scaled_init([hidden_units, n_labels], 0.425 + 1)
    b['1'] = tf.Variable(tf.zeros([hidden_units]))
    b['2'] = tf.Variable(tf.zeros([hidden_units]))
    b['3'] = tf.Variable(tf.zeros([hidden_units]))
    b['logits'] = tf.Variable(tf.zeros([n_labels]))

    def model(x):
        """Map a batch of flattened images to unnormalised class logits."""
        h1 = rho(tf.matmul(x, W['1']) + b['1'])
        h2 = rho(tf.matmul(h1, W['2']) + b['2'])
        h3 = rho(tf.matmul(h2, W['3']) + b['3'])
        return tf.matmul(h3, W['logits']) + b['logits']

    logits = model(x)

    # Keyword arguments: newer TensorFlow releases reject positional use of this
    # op, while the keyword form is accepted by old and new releases alike.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))
    # Small L2 penalty on the hidden-layer weights (the logits layer is not penalised).
    loss += 1e-5 * (tf.nn.l2_loss(W['1']) + tf.nn.l2_loss(W['2']) + tf.nn.l2_loss(W['3']))

    # `lr` is a constant tensor whose value the training loop overrides through feed_dict.
    lr = tf.constant(0.001)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    # Fraction of examples in the fed batch whose arg-max prediction differs from y.
    compute_error = tf.reduce_mean(tf.to_float(tf.not_equal(tf.argmax(logits, 1), y)))
In [4]:
# Open a session on the graph built above and initialise every variable.
sess = tf.InteractiveSession(graph=graph)
print('Beginning training')
sess.run(tf.initialize_all_variables())

num_batches = mnist.train.num_examples // batch_size
num_val_batches = mnist.validation.num_examples // batch_size

# Running averages used only for progress reporting.
ce_ema = 2.3 # - log(0.1)
err_ema = 0.9
# NOTE(review): risk_loss_ema is never read in this notebook, and -log(0.5) is
# about 0.69 rather than 0.3; confirm whether it can be dropped.
risk_loss_ema = 0.3

learning_rate = 0.001
for epoch in range(training_epochs):
    # drop the step size for the final epochs
    if epoch >= 20:
        learning_rate = 0.0001

    # we're training on all data so we do not keep the validation set separate:
    # every epoch takes gradient steps on the train split, then on the validation split
    for split, steps in ((mnist.train, num_batches), (mnist.validation, num_val_batches)):
        for step in range(steps):
            bx, by = split.next_batch(batch_size)
            batch_err, batch_loss = sess.run(
                [optimizer, compute_error, loss],
                feed_dict={x: bx, y: by, lr: learning_rate})[1:]
            ce_ema = 0.95 * ce_ema + 0.05 * batch_loss
            err_ema = 0.95 * err_ema + 0.05 * batch_err

    print('Epoch:', epoch, '|', 'ce ema of loss for epoch:', ce_ema, 'error (%):', 100*err_ema)

# Final loss and error (%) on the held-out test split.
test_feed = {x: mnist.test.images, y: mnist.test.labels}
print('MNIST classification loss and error:',
      sess.run([loss, 100*compute_error], feed_dict=test_feed))
In [5]:
def _confidence_ops(batch_logits):
    """Build the confidence statistics for a batch of logits.

    Returns a 5-tuple of tensors:
      softmax probabilities, max softmax probability per row (kept as a column),
      KL[p||u] per row against the uniform distribution over n_labels classes,
      and the mean and variance of that KL over the batch axis.
    """
    probs = tf.nn.softmax(batch_logits)
    top_prob = tf.reduce_max(probs, reduction_indices=[1], keep_dims=True)
    # KL[p||u] = log(K) + sum_k p_k log p_k; the 1e-11 keeps log() finite at p_k == 0.
    kl_uniform = tf.log(float(n_labels)) + tf.reduce_sum(
        probs * tf.log(tf.abs(probs) + 1e-11), reduction_indices=[1], keep_dims=True)
    kl_mean, kl_var = tf.nn.moments(kl_uniform, axes=[0])
    return probs, top_prob, kl_uniform, kl_mean, kl_var


# Row-wise mask: True where the arg-max prediction equals the fed label.
is_correct = tf.equal(tf.argmax(logits, 1), y)

# Statistics over every fed example.
s, s_prob, kl_all, m_all, v_all = _confidence_ops(logits)

# Statistics restricted to correctly classified examples.
logits_right = tf.boolean_mask(logits, is_correct)
s_right, s_right_prob, kl_right, m_right, v_right = _confidence_ops(logits_right)

# Statistics restricted to misclassified examples.
logits_wrong = tf.boolean_mask(logits, tf.logical_not(is_correct))
s_wrong, s_wrong_prob, kl_wrong, m_wrong, v_wrong = _confidence_ops(logits_wrong)

# Accuracy in percent on the fed batch.
acc = 100*tf.reduce_mean(tf.to_float(is_correct))
In [14]:
# One pass over the test split: error (%), KL[p||u] scores and max softmax
# probabilities for all / correctly classified / misclassified examples.
test_feed = {x: mnist.test.images, y: mnist.test.labels}
err, kl_a, kl_r, kl_w, s_p, s_rp, s_wp = sess.run(
    [100 - acc, kl_all, kl_right, kl_wrong, s_prob, s_right_prob, s_wrong_prob],
    feed_dict=test_feed)

print('MNIST Error (%)| Prediction Prob (mean, std) | PProb Right (mean, std) | PProb Wrong (mean, std):')
summary = [err]
for probs in (s_p, s_rp, s_wp):
    summary += ['|', np.mean(probs), np.std(probs)]
print(*summary)

# Success detection: correctly classified examples are the positive class and
# a higher score should mean "more likely correct".
print('\nSuccess Detection')
print('Success base rate (%):', round(100-err,2))
for title, safe, risky in (
        ('KL[p||u]: Right/Wrong classification distinction', kl_r, kl_w),
        ('Prediction Prob: Right/Wrong classification distinction', s_rp, s_wp)):
    print(title)
    labels = np.concatenate((np.ones(safe.shape[0], dtype=np.int32),
                             np.zeros(risky.shape[0], dtype=np.int32)))
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

# Error detection: misclassified examples are the positive class, so the
# scores are negated to make "higher" mean "more likely wrong".
print('\nError Detection')
print('Error base rate (%):', round(err,2))
for title, safe, risky in (
        ('KL[p||u]: Right/Wrong classification distinction', -kl_r, -kl_w),
        ('Prediction Prob: Right/Wrong classification distinction', -s_rp, -s_wp)):
    print(title)
    labels = np.concatenate((np.zeros(safe.shape[0], dtype=np.int32),
                             np.ones(risky.shape[0], dtype=np.int32)))
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
In [32]:
def show_ood_detection_results(error_rate_for_in, in_examples, out_examples):
    """Print AUPR/AUROC for telling in-distribution from out-of-distribution inputs.

    Args:
      error_rate_for_in: classification error on the in-distribution set, in
        percent (it is divided by 100 to obtain the fraction of correct examples).
      in_examples: in-distribution inputs; only its number of rows is used here.
      out_examples: out-of-distribution inputs, fed to the network through `x`.

    Returns:
      None; all results are printed.

    NOTE(review): the in-distribution scores are not recomputed from
    `in_examples`. They are read from the notebook globals kl_a, s_p, kl_r and
    s_rp produced by the test-set evaluation cell, so that cell must have been
    run and `in_examples` must be the same set it evaluated.
    """
    def _report(title, in_scores, out_scores, detect_out):
        # Print the section title followed by AUPR and AUROC for one score type.
        # detect_out=False: in-distribution rows are the positive class.
        # detect_out=True: scores are negated and OOD rows are the positive class.
        print(title)
        if detect_out:
            in_scores, out_scores = -in_scores, -out_scores
        n_in = in_scores.shape[0]
        labels = np.zeros((n_in + out_scores.shape[0]), dtype=np.int32)
        if detect_out:
            labels[n_in:] = 1
        else:
            labels[:n_in] = 1
        scores = np.squeeze(np.vstack((in_scores, out_scores)))
        print('AUPR (%):', round(100*sk.average_precision_score(labels, scores), 2))
        print('AUROC (%):', round(100*sk.roc_auc_score(labels, scores), 2))

    kl_oos, s_p_oos = sess.run([kl_all, s_prob], feed_dict={x: out_examples})

    n_in = in_examples.shape[0]
    n_out = out_examples.shape[0]
    # Expected number of correctly classified in-distribution examples. This uses
    # the error rate passed in by the caller rather than the notebook global `err`.
    n_in_correct = (1 - error_rate_for_in/100)*n_in

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(s_p_oos), np.std(s_p_oos))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*n_in/(n_out + n_in),2))
    _report('KL[p||u]: Normality Detection', kl_a, kl_oos, False)
    _report('Prediction Prob: Normality Detection', s_p, s_p_oos, False)

    print('Normality base rate (%):', round(100*n_in_correct/(n_out + n_in_correct),2))
    _report('KL[p||u]: Normality Detection (relative to correct examples)',
            kl_r, kl_oos, False)
    _report('Prediction Prob: Normality Detection (relative to correct examples)',
            s_rp, s_p_oos, False)

    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*n_out/(n_out + n_in),2))
    _report('KL[p||u]: Abnormality Detection', kl_a, kl_oos, True)
    _report('Prediction Prob: Abnormality Detection', s_p, s_p_oos, True)

    print('Abnormality base rate (%):', round(100*n_out/(n_out + n_in_correct),2))
    _report('KL[p||u]: Abnormality Detection (relative to correct examples)',
            kl_r, kl_oos, True)
    _report('Prediction Prob: Abnormality Detection (relative to correct examples)',
            s_rp, s_p_oos, True)
In [12]:
import scipy.io as sio
import scipy.misc as scimisc
In [13]:
# Indices of the alphabets to keep.
safe_list = [0,2,5,6,8,12,13,14,15,16,17,18,19,21,26] # other alphabets have characters which look like digits
m = sio.loadmat("./data/data_background.mat")

# Walk the nested arrays under m['images'] for the selected alphabets; every
# leaf image is inverted (1 - pixel), resized to 28x28 and flattened to one row.
# NOTE(review): scipy.misc.imresize rescales its output to 8-bit 0..255 values
# and was removed in SciPy 1.3. If the MNIST inputs are floats in [0, 1],
# these rows are on a different scale; confirm that this is intended.
squished_set = [
    scimisc.imresize(1 - example[0], (28,28)).reshape(1, 28*28)
    for safe_number in safe_list
    for alphabet in m['images'][safe_number]
    for letters in alphabet
    for letter in letters
    for example in letter
]

# Stack the single-row arrays into one (num_images, 784) matrix.
safe_images = np.concatenate(squished_set, axis=0)
In [33]:
# Omniglot characters as the out-of-distribution set, MNIST test as in-distribution.
print('Omniglot\n')
show_ood_detection_results(
    error_rate_for_in=err,
    in_examples=mnist.test.images,
    out_examples=safe_images,
)
In [34]:
pickle_file = './data/notMNIST.pickle'

# NOTE(security): unpickling can execute arbitrary code, so only load a
# notMNIST.pickle that comes from a trusted source.
with open(pickle_file, 'rb') as pickle_handle:
    # Only the test split is kept, flattened to one 784-value row per image.
    # The rest of the unpickled dict is never bound to a name, so it can be
    # garbage-collected right away.
    test_dataset = pickle.load(pickle_handle, encoding='latin1')['test_dataset'].reshape((-1, 28 * 28))

print('notMNIST\n')
show_ood_detection_results(
    error_rate_for_in=err,
    in_examples=mnist.test.images,
    out_examples=test_dataset,
)
In [35]:
# Only the third value returned by load_data10() is used here.
_, _, X_test, _ = load_data10()

# Convert to a single grayscale channel, shrink to 28x28, then evaluate to a NumPy array.
cifar_gray = tf.image.rgb_to_grayscale(X_test)
cifar_small = tf.image.resize_images(cifar_gray, 28, 28)
cifar_batch = sess.run(cifar_small)

print('CIFAR-10bw\n')
# Flatten each image to a 784-value row to match the network input.
cifar_rows = cifar_batch.reshape(-1, 28*28)
show_ood_detection_results(err, mnist.test.images, cifar_rows)
In [36]:
print('Sheer White Gaussian Noise\n')
# 10000 synthetic "images" of standard-normal noise, one 784-value row each.
gaussian_noise = np.random.normal(size=(10000, 28*28))
show_ood_detection_results(err, mnist.test.images, gaussian_noise)
# caveat: if you let the scale = 100 this will fail because it is inputting too much
# energy into the network, but this is unrealistic
In [37]:
print('Sheer Uniform Noise\n')
# 10000 synthetic "images" of uniform noise drawn with np.random.uniform's default bounds.
uniform_noise = np.random.uniform(size=(10000, 28*28))
show_ood_detection_results(err, mnist.test.images, uniform_noise)