In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
import tensorflow as tf
import numpy as np
import sys
import os
import pickle
from load_cifar10 import load_data10
import sklearn.metrics as sk
# training parameters
learning_rate = 0.001
training_epochs = 30
batch_size = 128
# architecture parameters
n_labels = 10
image_pixels = 28 * 28
bottleneck = 128
In [2]:
def add_noise(batch, complexity=0.5):
    # additive Gaussian noise; `complexity` sets the noise standard deviation
    return batch + np.random.normal(size=batch.shape, scale=1e-9 + complexity)

from skimage.filters import gaussian

def blur(img, complexity=0.5):
    # Gaussian blur; note that a scalar sigma smooths along every axis of the
    # (batch, 28, 28) array, including across the batch dimension
    image = img.reshape((-1, 28, 28))
    return gaussian(image, sigma=5*complexity).reshape((-1, 28*28))
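# A quick shape sanity check for the two distortions (a minimal sketch; `dummy`
# is a made-up batch, not MNIST data):
dummy = np.random.rand(2, 28*28)
assert add_noise(dummy).shape == (2, 28*28)
assert blur(dummy).shape == (2, 28*28)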
In [3]:
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[None, image_pixels])
    y = tf.placeholder(dtype=tf.int64, shape=[None])
    risk_labels = tf.placeholder(dtype=tf.float32, shape=[None])
    is_unfrozen = tf.placeholder(tf.bool)

    def gelu_fast(x):
        # tanh approximation of the Gaussian Error Linear Unit
        return 0.5 * x * (1 + tf.tanh(tf.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
    rho = gelu_fast

    W = {}
    b = {}
    with tf.variable_scope("in_sample"):
        W['1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([image_pixels, 256]), 0))
        W['2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['3'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['logits'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, n_labels]), 0))
        W['bottleneck'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, bottleneck]), 0))
        W['decode1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([bottleneck, 256]), 0))
        W['decode2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 256]), 0))
        W['reconstruction'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, image_pixels]), 0))

        b['1'] = tf.Variable(tf.zeros([256]))
        b['2'] = tf.Variable(tf.zeros([256]))
        b['3'] = tf.Variable(tf.zeros([256]))
        b['logits'] = tf.Variable(tf.zeros([n_labels]))
        b['bottleneck'] = tf.Variable(tf.zeros([bottleneck]))
        b['decode1'] = tf.Variable(tf.zeros([256]))
        b['decode2'] = tf.Variable(tf.zeros([256]))
        b['reconstruction'] = tf.Variable(tf.zeros([image_pixels]))

    with tf.variable_scope("out_of_sample"):
        W['residual_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([image_pixels, 512]), 0))
        W['hidden_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([256, 512]), 0))
        W['logits_to_risk1'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([n_labels, 512]), 0))
        W['risk2'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([512, 128]), 0))
        W['risk'] = tf.Variable(tf.nn.l2_normalize(tf.random_normal([128, 1]), 0))

        b['risk1'] = tf.Variable(tf.zeros([512]))
        b['risk2'] = tf.Variable(tf.zeros([128]))
        b['risk'] = tf.Variable(tf.zeros([1]))

    def risk_net(x):
        # classifier trunk
        h1 = rho(tf.matmul(x, W['1']) + b['1'])
        h2 = rho(tf.matmul(h1, W['2']) + b['2'])
        h3 = rho(tf.matmul(h2, W['3']) + b['3'])
        logits_out = tf.matmul(h3, W['logits']) + b['logits']
        # autoencoder branch off the second hidden layer
        hidden_to_bottleneck = rho(tf.matmul(h2, W['bottleneck']) + b['bottleneck'])
        d1 = rho(tf.matmul(hidden_to_bottleneck, W['decode1']) + b['decode1'])
        d2 = rho(tf.matmul(d1, W['decode2']) + b['decode2'])
        recreation = tf.matmul(d2, W['reconstruction']) + b['reconstruction']
        # the risk head reads the logits, the reconstruction residual, and the hidden layer
        risk1 = rho(tf.matmul(logits_out, W['logits_to_risk1']) +
                    tf.matmul(tf.square(x - recreation), W['residual_to_risk1']) +
                    tf.matmul(h2, W['hidden_to_risk1']) + b['risk1'])
        risk2 = rho(tf.matmul(risk1, W['risk2']) + b['risk2'])
        risk_out = tf.matmul(risk2, W['risk'])
        return logits_out, recreation, tf.squeeze(risk_out)

    logits, reconstruction, risk = risk_net(x)

    ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y))
    rec_error = tf.reduce_mean(tf.square(x - reconstruction))
    loss = 0.9 * ce + 0.1 * rec_error

    lr = tf.constant(0.001)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    compute_error = 100*tf.reduce_mean(tf.to_float(tf.not_equal(tf.argmax(logits, 1), y)))
    compute_risk_error = 100*tf.reduce_mean(tf.to_float(tf.not_equal(tf.to_int64(tf.round(tf.sigmoid(risk))),
                                                                     tf.to_int64(tf.round(risk_labels)))))
In [4]:
sess = tf.InteractiveSession(graph=graph)

print('Beginning training: Phase 1')

# Adam needs special care: initialize everything except the out-of-sample (risk)
# weights, which are handled in phase 2; this also initializes Adam's slot
# variables for the in-sample weights
in_sample_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "in_sample")
out_of_sample_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "out_of_sample")
sess.run(tf.initialize_variables(set(tf.all_variables()) - set(out_of_sample_vars)))

num_batches = int(mnist.train.num_examples / batch_size)
save_every = int(num_batches/3.1)   # save training information 3 times per epoch

ce_ema = 2.3          # -log(0.1)
err_ema = 0.9
risk_loss_ema = 0.3   # initial guess; chance-level sigmoid cross-entropy is -log(0.5) ~= 0.69

learning_rate = 0.001
for epoch in range(training_epochs):
    if epoch >= 20:
        learning_rate = 0.0001
    for i in range(num_batches):
        bx, by = mnist.train.next_batch(batch_size)
        _, err, l = sess.run([optimizer, compute_error, ce], feed_dict={x: bx, y: by, is_unfrozen: True,
                                                                        lr: learning_rate})
        ce_ema = ce_ema * 0.95 + 0.05 * l
        err_ema = err_ema * 0.95 + 0.05 * err
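
# Optional check before moving on (a sketch; evaluates the whole test set in one run):
phase1_test_err = sess.run(compute_error, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print('Phase 1 test error (%):', phase1_test_err)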
In [5]:
print('Entering out-of-distribution training phase')

# Adam again needs special care: build the phase-2 optimizer, then initialize its
# new slot variables and the out-of-sample weights without touching the trained
# in-sample weights
risk_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(risk, risk_labels))
phase2_vars = list(set(tf.all_variables()) - set(in_sample_vars))
curr_lr = learning_rate
risk_optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(risk_loss, var_list=phase2_vars)
sess.run(tf.initialize_variables(set(tf.all_variables()) - set(in_sample_vars)))

err_ema = 50
for epoch in range(15):
    for i in range(num_batches):
        # draw 1.5 batches of MNIST (next_batch needs an integer count); the first
        # third of the assembled batch stays clean, the rest becomes out-of-distribution
        bx, by = mnist.train.next_batch(3*batch_size//2)
        bx1 = bx[0:batch_size//6]                                                          # clean training images
        bx2 = mnist.validation.next_batch(batch_size//3)[0][batch_size//6:batch_size//3]   # clean validation images
        distortion = np.random.uniform(low=0.9, high=1.2)
        bx3 = add_noise(bx[batch_size//3:4*batch_size//6], complexity=distortion)          # noised images
        distortion = np.random.uniform(low=0.9, high=1)
        bx4 = blur(bx[4*batch_size//6:5*batch_size//6], complexity=distortion)             # blurred images
        bx5 = np.zeros(shape=(batch_size - 5*batch_size//6, 28*28))                        # rotated images
        for k in range(5*batch_size//6, batch_size):
            if by[k] == 0:
                # a rotated 0 still looks like a 0, so distort it with noise instead
                bx5[k - 5*batch_size//6] = add_noise(bx[k], complexity=1).reshape((28*28,))
            else:
                bx5[k - 5*batch_size//6] = np.rot90(bx[k].reshape((28,28)),
                                                    k=np.random.choice([1,3])).reshape((28*28,))
        risks = np.zeros(batch_size)
        risks[:batch_size//3] = 1   # only bx1 and bx2 (the clean images) are labeled in-distribution
        bx = np.vstack((bx1, bx2, bx3, bx4, bx5))
        _, rl, err = sess.run([risk_optimizer, risk_loss, compute_risk_error],
                              feed_dict={x: bx, risk_labels: risks, is_unfrozen: False})
        risk_loss_ema = risk_loss_ema * 0.95 + 0.05 * rl
        err_ema = err_ema * 0.95 + 0.05 * err
    print('Epoch:', epoch, '| risk loss ema:', risk_loss_ema, '| error (%):', err_ema)
In [6]:
# load notMNIST, CIFAR-10, and Omniglot
pickle_file = './data/notMNIST.pickle'
with open(pickle_file, 'rb') as f:
    save = pickle.load(f, encoding='latin1')
    notmnist_dataset = save['test_dataset'].reshape((-1, 28 * 28))
    del save

_, _, X_test, _ = load_data10()
cifar_batch = sess.run(tf.image.resize_images(tf.image.rgb_to_grayscale(X_test), 28, 28))

import scipy.io as sio
import scipy.misc as scimisc

# other alphabets have characters which overlap
safe_list = [0,2,5,6,8,12,13,14,15,16,17,18,19,21,26]
m = sio.loadmat("./data/data_background.mat")

squished_set = []
for safe_number in safe_list:
    for alphabet in m['images'][safe_number]:
        for letters in alphabet:
            for letter in letters:
                for example in letter:
                    squished_set.append(scimisc.imresize(1 - example[0], (28,28)).reshape(1, 28*28))

omni_images = np.concatenate(squished_set, axis=0)
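
# Quick look at what was loaded (shapes depend on the local data files):
print('notMNIST:', notmnist_dataset.shape, '| CIFAR-10bw:', cifar_batch.shape, '| Omniglot:', omni_images.shape)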
In [7]:
err, r_err, r, conf = sess.run([compute_error, compute_risk_error, tf.sigmoid(risk), tf.nn.softmax(logits)],
                               feed_dict={x: mnist.test.images, y: mnist.test.labels, risk_labels: np.ones(10000)})
r_right = r[np.argmax(conf, axis=1).astype(np.uint8) == mnist.test.labels]
print('MNIST Digit Error (%) | MNIST Riskiness Error (%) | Digit Confidence (mean, std):')
print(err, '|', r_err, '|', np.mean(np.max(conf, axis=1)), np.std(np.max(conf, axis=1)))
In [9]:
s = tf.nn.softmax(logits)
s_prob = tf.reduce_max(s, reduction_indices=[1], keep_dims=True)
kl_all = tf.log(10.) + tf.reduce_sum(s * tf.log(tf.abs(s) + 1e-11), reduction_indices=[1], keep_dims=True)
m_all, v_all = tf.nn.moments(kl_all, axes=[0])
logits_right = tf.boolean_mask(logits, tf.equal(tf.argmax(logits, 1), y))
s_right = tf.nn.softmax(logits_right)
s_right_prob = tf.reduce_max(s_right, reduction_indices=[1], keep_dims=True)
kl_right = tf.log(10.) + tf.reduce_sum(s_right * tf.log(tf.abs(s_right) + 1e-11),
                                       reduction_indices=[1], keep_dims=True)
m_right, v_right = tf.nn.moments(kl_right, axes=[0])
logits_wrong = tf.boolean_mask(logits, tf.not_equal(tf.argmax(logits, 1), y))
s_wrong = tf.nn.softmax(logits_wrong)
s_wrong_prob = tf.reduce_max(s_wrong, reduction_indices=[1], keep_dims=True)
kl_wrong = tf.log(10.) + tf.reduce_sum(s_wrong * tf.log(tf.abs(s_wrong) + 1e-11),
                                       reduction_indices=[1], keep_dims=True)
m_wrong, v_wrong = tf.nn.moments(kl_wrong, axes=[0])
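
# For intuition, the same two confidence scores in plain NumPy on a toy softmax
# vector (a minimal sketch; `p` is made up, not model output):
p = np.array([[0.9] + [0.1/9]*9])
toy_max_prob = p.max(axis=1)                                      # prediction probability score
toy_kl_from_uniform = np.log(10.) + (p * np.log(p)).sum(axis=1)   # KL[p||u] score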
In [12]:
kl_a, kl_r, kl_w, s_p, s_rp, s_wp = sess.run(
    [kl_all, kl_right, kl_wrong, s_prob, s_right_prob, s_wrong_prob],
    feed_dict={x: mnist.test.images, y: mnist.test.labels})
print('\nSuccess Detection')
print('Success base rate (%):', round(100-err,2))
print('KL[p||u]: Right/Wrong classification distinction')
safe, risky = kl_r, kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = s_rp, s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[:safe.shape[0]] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
print('\nError Detection')
print('Error base rate (%):', round(err,2))
safe, risky = -kl_r, -kl_w
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('KL[p||u]: Right/Wrong classification distinction')
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
print('Prediction Prob: Right/Wrong classification distinction')
safe, risky = -s_rp, -s_wp
labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
labels[safe.shape[0]:] += 1
examples = np.squeeze(np.vstack((safe, risky)))
print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
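
# The AUPR/AUROC boilerplate above recurs throughout the rest of the notebook; a
# small helper could factor it out (shown only as a sketch, and not used below so
# that the original cells stay unchanged):
def print_curve_scores(safe, risky, positive='safe'):
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    if positive == 'safe':
        labels[:safe.shape[0]] += 1
    else:
        labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))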
In [17]:
def show_ood_detection_results_softmax(error_rate_for_in, in_examples, out_examples):
    # relies on the in-distribution statistics computed above (kl_a, kl_r, s_p, s_rp)
    kl_oos, s_p_oos = sess.run([kl_all, s_prob], feed_dict={x: out_examples})

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(s_p_oos), np.std(s_p_oos))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
        out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection')
    safe, risky = kl_a, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
    print('Prediction Prob: Normality Detection')
    safe, risky = s_p, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - error_rate_for_in/100)*in_examples.shape[0]/
        (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Normality Detection (relative to correct examples)')
    safe, risky = kl_r, kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
    print('Prediction Prob: Normality Detection (relative to correct examples)')
    safe, risky = s_rp, s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
        out_examples.shape[0] + in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection')
    safe, risky = -kl_a, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
    print('Prediction Prob: Abnormality Detection')
    safe, risky = -s_p, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
        (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('KL[p||u]: Abnormality Detection (relative to correct examples)')
    safe, risky = -kl_r, -kl_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
    print('Prediction Prob: Abnormality Detection (relative to correct examples)')
    safe, risky = -s_rp, -s_p_oos
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
In [18]:
print('Omniglot (Softmax version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, omni_images.reshape(-1, 28*28))
In [19]:
print('notMNIST (Softmax version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, notmnist_dataset)
In [20]:
print('CIFAR-10bw (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, cifar_batch.reshape(-1, 28*28))
In [21]:
print('Sheer White Gaussian Noise (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, np.random.normal(size=(10000, 28*28)))
In [22]:
print('Sheer Uniform Noise (Softmax Version)\n')
show_ood_detection_results_softmax(err, mnist.test.images, np.random.uniform(size=(10000, 28*28)))
In [23]:
def show_ood_detection_results(error_rate_for_in, in_examples, out_examples):
    # relies on the in-distribution risk scores computed above (r and r_right)
    r_oos, conf = sess.run([tf.sigmoid(risk), tf.nn.softmax(logits)], feed_dict={x: out_examples})

    print('OOD Example Prediction Probability (mean, std):')
    print(np.mean(np.max(conf, axis=1)), np.std(np.max(conf, axis=1)))

    print('\nNormality Detection')
    print('Normality base rate (%):', round(100*in_examples.shape[0]/(
        out_examples.shape[0] + in_examples.shape[0]),2))
    print('Normality Detection')
    safe, risky = r.reshape(-1,1), r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Normality base rate (%):', round(100*(1 - error_rate_for_in/100)*in_examples.shape[0]/
        (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('Normality Detection (relative to correct examples)')
    safe, risky = r_right.reshape(-1,1), r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[:safe.shape[0]] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('\n\nAbnormality Detection')
    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/(
        out_examples.shape[0] + in_examples.shape[0]),2))
    print('Abnormality Detection')
    safe, risky = 1 - r.reshape(-1,1), 1 - r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))

    print('Abnormality base rate (%):', round(100*out_examples.shape[0]/
        (out_examples.shape[0] + (1 - error_rate_for_in/100)*in_examples.shape[0]),2))
    print('Abnormality Detection (relative to correct examples)')
    safe, risky = 1 - r_right.reshape(-1,1), 1 - r_oos.reshape(-1, 1)
    labels = np.zeros((safe.shape[0] + risky.shape[0]), dtype=np.int32)
    labels[safe.shape[0]:] += 1
    examples = np.squeeze(np.vstack((safe, risky)))
    print('AUPR (%):', round(100*sk.average_precision_score(labels, examples), 2))
    print('AUROC (%):', round(100*sk.roc_auc_score(labels, examples), 2))
In [24]:
print('Omniglot\n')
show_ood_detection_results(err, mnist.test.images, omni_images.reshape(-1, 28*28))
In [25]:
print('notMNIST\n')
show_ood_detection_results(err, mnist.test.images, notmnist_dataset)
In [26]:
print('CIFAR-10bw\n')
show_ood_detection_results(err, mnist.test.images, cifar_batch.reshape(-1, 28*28))
In [27]:
print('Sheer White Gaussian Noise\n')
show_ood_detection_results(err, mnist.test.images, np.random.normal(size=(10000, 28*28)))
In [28]:
print('Sheer Uniform Noise\n')
show_ood_detection_results(err, mnist.test.images, np.random.uniform(size=(10000, 28*28)))