Here is astroNN; please take a look if you are interested in astronomy or in how neural networks are applied in astronomy.
For more resources on Bayesian Deep Learning with Dropout Variational Inference, please refer to README.md
In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'
from tensorflow.keras.datasets import mnist
from tensorflow.keras import utils
import numpy as np
import pylab as plt
from astroNN.models import MNIST_BCNN
# disable eager execution to prevent potential error
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
In [3]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_train = utils.to_categorical(y_train, 10)
y_train = y_train.astype(np.float32)
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
# Create an astroNN neural network instance and set the basic parameters
net = MNIST_BCNN()
net.task = 'classification'
net.max_epochs = 5 # Just use 5 epochs for quick result
# Train the neural network
net.train(x_train, y_train)
As you can see below, most test images are classified correctly, except the last one, for which the model reports a high uncertainty. As a human, you can indeed argue that this 5 is badly written and could be read as a 6 or even as a badly written 8.
In [4]:
test_idx = [1, 2, 3, 4, 5, 8]
pred, pred_std = net.test(x_test[test_idx])
for counter, i in enumerate(test_idx):
    plt.figure(figsize=(3, 3), dpi=100)
    plt.title(f'Predicted Digit {pred[counter]}, Real Answer: {y_test[i]:{1}} \n'
              f'Total Uncertainty (Entropy): {(pred_std["total"][counter]):.{2}}')
    plt.imshow(x_test[i])
    plt.show()
    plt.close('all')
    plt.clf()
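The titles above report the total uncertainty as an entropy. As a rough illustration of how such a number can arise from dropout variational inference, here is a minimal NumPy sketch. It is not astroNN's implementation: the function name `predictive_entropy`, the `(T, N, C)` array layout, and the toy probabilities are all assumptions made for this example.

```python
import numpy as np

def predictive_entropy(mc_probs, eps=1e-12):
    """Entropy of the mean class probabilities over stochastic forward passes.

    mc_probs: hypothetical array of shape (T, N, C) holding softmax outputs
    from T dropout-enabled passes, for N images and C classes.
    """
    mean_probs = mc_probs.mean(axis=0)  # average over passes -> (N, C)
    return -np.sum(mean_probs * np.log(mean_probs + eps), axis=-1)  # (N,)

# Toy data (assumed, not model output): 20 passes, 1 image, 4 classes.
# Case 1: every pass agrees and is confident in class 0.
confident = np.tile([0.97, 0.01, 0.01, 0.01], (20, 1, 1))
# Case 2: each pass is certain, but the passes disagree on the class.
disagree = np.eye(4)[np.arange(20) % 4][:, None, :]

entropy_confident = predictive_entropy(confident)[0]
entropy_disagree = predictive_entropy(disagree)[0]
print(f'agreeing passes: {entropy_confident:.2f}, disagreeing passes: {entropy_disagree:.2f}')
```

In this toy setting the disagreeing passes average out to a uniform distribution, so their entropy is much larger than that of the agreeing passes, which is the kind of behaviour that flags an ambiguous digit.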
Since the neural network is trained on MNIST images without any data augmentation, rotated MNIST images should look 'alien' to it, and it should report a high uncertainty. Indeed, the neural network tells us it is very uncertain about its predictions for the rotated images.
In [5]:
test_rot_idx = [9, 10, 11]
test_rot = x_test[test_rot_idx]
for counter, j in enumerate(test_rot):
    test_rot[counter] = np.rot90(j)
pred_rot, pred_rot_std = net.test(test_rot)
for counter, i in enumerate(test_rot_idx):
    plt.figure(figsize=(3, 3), dpi=100)
    plt.title(f'Predicted Digit {pred_rot[counter]}, Real Answer: {y_test[i]:{1}} \n'
              f'Total Uncertainty (Entropy): {(pred_rot_std["total"][counter]):.{2}}')
    plt.imshow(test_rot[counter])
    plt.show()
    plt.close('all')
    plt.clf()
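The per-image rotation loop in the cell above can also be written as a single call on the whole batch. The sketch below assumes a batch of square images with shape `(N, H, W)`; the helper name `rotate_batch` and the toy array are made up for this example and are not part of astroNN.

```python
import numpy as np

def rotate_batch(images, k=1):
    """Rotate every image in an (N, H, W) batch by k * 90 degrees.

    axes=(1, 2) makes the rotation act on the image plane and leaves
    the batch axis untouched.
    """
    return np.rot90(images, k=k, axes=(1, 2))

# Toy batch standing in for a few MNIST images (assumed data).
batch = np.arange(3 * 4 * 4, dtype=np.float32).reshape(3, 4, 4)

rotated = rotate_batch(batch)
# Reference result: rotate each image separately, as in the loop above.
looped = np.stack([np.rot90(img) for img in batch])
print(np.array_equal(rotated, looped))
```

Unlike the in-place loop, this returns a new view or array and works for non-square images too, since the result does not have to fit back into the original slots.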