Tutorial Part 2: Learning MNIST Digit Classifiers

In the previous tutorial, we learned the basics of loading data into DeepChem and manipulating it with DeepChem's core objects. In this tutorial, you'll put those pieces together and train a basic image classification model in DeepChem. Why learn this material in DeepChem? Part of the reason is that image processing is an increasingly important part of AI for the life sciences, so knowing how to train image processing models will help you use some of the more advanced DeepChem features.

The MNIST dataset contains images of handwritten digits along with their human-annotated labels. The learning challenge for this dataset is to train a model that maps each digit image to its true label. MNIST has been a standard machine learning benchmark for decades.

For convenience, TensorFlow provides loader functions for the MNIST dataset, and we'll use them here. Note that these loaders are deprecated, which is why the cell below prints a series of warnings.

In [1]:
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# TODO: This is deprecated. Let's replace with a DeepChem native loader for maintainability.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

WARNING:tensorflow:From <ipython-input-2-a839aeb82f4b>:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/bharath/anaconda3/envs/deepchem/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
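
Because the loader above is deprecated, it can help to know what layout it produces so that the same arrays can be rebuilt from another source of raw MNIST data. The sketch below assumes the raw images arrive as `uint8` arrays of shape `(N, 28, 28)` with integer labels (the layout returned by, for example, `tf.keras.datasets.mnist.load_data()`). The helper name `to_flat_one_hot` is hypothetical and is not part of DeepChem or TensorFlow, and synthetic data stands in for the real download.

```python
import numpy as np

def to_flat_one_hot(images, labels, num_classes=10):
    """Convert raw digit images and integer labels to flat, one-hot arrays.

    Hypothetical helper: mirrors the layout used in this tutorial
    (784-dimensional float images in [0, 1], one-hot labels).
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    # Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1].
    X = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0
    # One-hot encode: row i has a 1 in column labels[i] and 0 elsewhere.
    y = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    y[np.arange(labels.shape[0]), labels] = 1.0
    return X, y

# Synthetic stand-in for real data, so the sketch runs without a download.
rng = np.random.RandomState(0)
raw_images = rng.randint(0, 256, size=(5, 28, 28)).astype(np.uint8)
raw_labels = np.array([3, 0, 9, 1, 3])

X, y = to_flat_one_hot(raw_images, raw_labels)
print(X.shape, y.shape)
```

Arrays in this layout could then be wrapped with `dc.data.NumpyDataset(X, y)` in the same way as `mnist.train.images` and `mnist.train.labels`.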

In [3]:
import deepchem as dc
import tensorflow as tf
from deepchem.models.tensorgraph.layers import Layer, Input, Reshape, Flatten, Conv2D, Label, Feature
from deepchem.models.tensorgraph.layers import Dense, SoftMaxCrossEntropy, ReduceMean, SoftMax

In [4]:
train = dc.data.NumpyDataset(mnist.train.images, mnist.train.labels)
valid = dc.data.NumpyDataset(mnist.validation.images, mnist.validation.labels)

In [5]:
tg = dc.models.TensorGraph(tensorboard=True, model_dir='/tmp/mnist', use_queue=False)
feature = Feature(shape=(None, 784))

# Images are square 28x28 (batch, height, width, channel)
make_image = Reshape(shape=(-1, 28, 28, 1), in_layers=[feature])

conv2d_1 = Conv2D(num_outputs=32, in_layers=[make_image])

conv2d_2 = Conv2D(num_outputs=64, in_layers=[conv2d_1])

flatten = Flatten(in_layers=[conv2d_2])

dense1 = Dense(out_channels=1024, activation_fn=tf.nn.relu, in_layers=[flatten])

dense2 = Dense(out_channels=10, in_layers=[dense1])

label = Label(shape=(None, 10))

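smce = SoftMaxCrossEntropy(in_layers=[label, dense2])
loss = ReduceMean(in_layers=[smce])
# Register the loss so that fit() knows what to minimize.
tg.set_loss(loss)

output = SoftMax(in_layers=[dense2])
# Register the softmax probabilities as the model's prediction output.
tg.add_output(output)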
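
To make the loss concrete: `dense2` produces ten unnormalized scores (logits) per image, `SoftMax` turns them into class probabilities, and the cross-entropy averaged over the batch measures how little probability the model assigns to the true class. The NumPy sketch below works through that arithmetic on a tiny hand-made batch with three classes. It illustrates the standard formula only and is not DeepChem's internal implementation; the function names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def mean_softmax_cross_entropy(labels, logits):
    # labels are one-hot, so the sum picks out -log p(true class) per example.
    per_example = -np.sum(labels * log_softmax(logits), axis=1)
    return per_example.mean()

# Two examples, three classes. The first example favors its true class;
# the second has uniform logits, so its loss is log(3).
logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

probs = np.exp(log_softmax(logits))
loss_value = mean_softmax_cross_entropy(labels, logits)
print(probs.sum(axis=1))  # each row of probabilities sums to 1
print(loss_value)
```

With ten classes and uniform predictions the per-example loss would be log(10), which is a useful reference point when watching the loss of an untrained model.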

In [6]:
# nb_epoch set to 0 to permit rendering of tutorials online.
# Set nb_epoch=10 for better results
tg.fit(train, nb_epoch=0)

TIMING: model fitting took 2.621 s

In [7]:
# Note that AUCs will be nonsensical without setting nb_epoch higher!
from sklearn.metrics import roc_curve, auc
import numpy as np

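# Predicted class probabilities for the validation images, shape (N, 10).
prediction = np.squeeze(tg.predict_on_batch(valid.X))

# One-vs-rest ROC curve and AUC for each of the ten digit classes.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(10):
    fpr[i], tpr[i], _ = roc_curve(valid.y[:, i], prediction[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    print("class %s:auc=%s" % (i, roc_auc[i]))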

class 0:auc=0.170555039138
class 1:auc=0.634619025945
class 2:auc=0.303693111629
class 3:auc=0.549712392397
class 4:auc=0.523565007169
class 5:auc=0.648496399959
class 6:auc=0.316489936331
class 7:auc=0.708521348315
class 8:auc=0.555897147512
class 9:auc=0.377021042837
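
The per-class numbers above are one-vs-rest AUCs: for class i, the AUC is the probability that a randomly chosen image of digit i receives a higher class-i score than a randomly chosen image of any other digit. A model trained for zero epochs has no reason to rank images correctly, so the values above carry no meaning. The sketch below computes this pairwise definition directly in NumPy on synthetic scores. `pairwise_auc` is a hypothetical helper written for illustration; it should agree, up to floating-point error, with the `roc_curve` plus `auc` combination used in the cell above.

```python
import numpy as np

def pairwise_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Memory use is O(n_pos * n_neg),
    so this is only meant for small illustrative inputs.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true]
    neg = scores[~y_true]
    # diff[a, b] > 0 means positive a is ranked above negative b.
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

y_true = np.array([1, 1, 0, 0, 0])
perfect = np.array([0.9, 0.8, 0.3, 0.2, 0.1])  # positives always score higher
reversed_scores = 1.0 - perfect                # positives always score lower
constant = np.full(5, 0.5)                     # no ranking information at all

print(pairwise_auc(y_true, perfect))           # 1.0
print(pairwise_auc(y_true, reversed_scores))   # 0.0
print(pairwise_auc(y_true, constant))          # 0.5
```

An AUC well below 0.5, such as the value for class 0 above, means the scores are anti-correlated with the label, which is just as uninformative about a trained model's quality as a value near 0.5.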