Tutorial Part 2: Learning MNIST Digit Classifiers

In the previous tutorial, we learned the basics of loading data into DeepChem and manipulating it with the core DeepChem objects. In this tutorial, you'll put those parts together and train a basic image classification model in DeepChem. Why learn image classification in DeepChem? Part of the reason is that image processing is an increasingly important part of AI for the life sciences, so knowing how to train image processing models will be useful for some of the more advanced DeepChem features.

The MNIST dataset contains images of handwritten digits along with their human-annotated labels. The learning challenge for this dataset is to train a model that maps each digit image to its true label. MNIST has been a standard machine learning benchmark for decades.
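
To make the image-to-label mapping concrete, here is a minimal NumPy sketch of the data layout used later in this tutorial: each image is a 28x28 grid flattened to 784 values, and each label is a one-hot vector of length 10 (which is what `one_hot=True` asks the loader for). The `one_hot` helper and the random stand-in images are illustrative assumptions, not part of DeepChem or TensorFlow.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Hypothetical helper: turn integer digit labels into one-hot rows."""
    encoded = np.zeros((len(labels), num_classes), dtype=np.float32)
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded


# Stand-in data: three fake "images" with random pixel intensities in [0, 1).
rng = np.random.default_rng(0)
images = rng.random((3, 28, 28)).astype(np.float32)
digits = np.array([7, 0, 3])

flat_images = images.reshape(len(images), -1)  # one row of 784 pixels per image
labels = one_hot(digits)                       # one row of 10 entries per image

print(flat_images.shape, labels.shape)
print(labels.argmax(axis=1))  # argmax recovers the integer digits
```

A classifier for this task takes rows like those in `flat_images` as input and is trained to produce rows like those in `labels`.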


This tutorial and the rest in this sequence are designed to be run in Google Colab.


We recommend running this tutorial on Google Colab. To set up your environment there, run the following cell of installation commands. If you'd rather run the tutorial locally, don't run these commands, since they download and install a new Anaconda Python setup.

In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')

In [0]:
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
# TODO: This is deprecated. Let's replace with a DeepChem native loader for maintainability.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

WARNING:tensorflow:From <ipython-input-3-227956e9c9c1>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.

In [0]:
import deepchem as dc
import tensorflow as tf
from tensorflow.keras.layers import Reshape, Conv2D, Flatten, Dense, Softmax

In [0]:
train = dc.data.NumpyDataset(mnist.train.images, mnist.train.labels)
valid = dc.data.NumpyDataset(mnist.validation.images, mnist.validation.labels)

In [0]:
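keras_model = tf.keras.Sequential([
    Reshape((28, 28, 1)),  # flat 784-pixel rows -> 28x28 single-channel images
    Conv2D(filters=32, kernel_size=5, activation=tf.nn.relu),
    Conv2D(filters=64, kernel_size=5, activation=tf.nn.relu),
    Flatten(),  # collapse the feature maps into one vector per image
    Dense(1024, activation=tf.nn.relu),
    Dense(10),  # one output per digit class
    Softmax()   # convert the outputs to class probabilities
])
model = dc.models.KerasModel(keras_model, dc.models.losses.CategoricalCrossEntropy())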
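
To see how the image shrinks as it passes through the two convolutions, the following sketch works through the shape arithmetic. It assumes the Keras `Conv2D` defaults of stride 1 and `'valid'` (no) padding, which the cell above does not override; `conv_output_size` is a hypothetical helper written for this illustration.

```python
def conv_output_size(size, kernel_size, stride=1):
    """Output width/height of a convolution with 'valid' (no) padding."""
    return (size - kernel_size) // stride + 1


side = 28  # MNIST images are 28x28
side = conv_output_size(side, kernel_size=5)  # after the first Conv2D
first_conv_side = side
side = conv_output_size(side, kernel_size=5)  # after the second Conv2D

# Flattening the second convolution's output (as a Flatten layer would)
# yields side * side * filters values per image, with 64 filters here.
flattened = side * side * 64

print(first_conv_side, side, flattened)  # 24 20 25600
```

Under these assumptions, the 1024-unit dense layer receives 25,600 features per image.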

In [7]:
model.fit(train, nb_epoch=2)

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:169: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/optimizers.py:76: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:258: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:260: The name tf.variables_initializer is deprecated. Please use tf.compat.v1.variables_initializer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/keras_model.py:200: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

In [8]:
from sklearn.metrics import roc_curve, auc
import numpy as np

# Predicted class probabilities for every validation image.
prediction = np.squeeze(model.predict_on_batch(valid.X))

# One-vs-rest ROC curve and AUC for each of the 10 digit classes.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(10):
    fpr[i], tpr[i], _ = roc_curve(valid.y[:, i], prediction[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    print("class %s:auc=%s" % (i, roc_auc[i]))

class 0:auc=0.9999482812520925
class 1:auc=0.9999327470315621
class 2:auc=0.9999223382455529
class 3:auc=0.9999378924197698
class 4:auc=0.999804920932277
class 5:auc=0.9997608046652174
class 6:auc=0.9999347825797615
class 7:auc=0.9997099080694587
class 8:auc=0.999882187740275
class 9:auc=0.9996286953889618
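
One way to read these numbers: the AUC for a class is the probability that a randomly chosen image of that digit receives a higher score for that class than a randomly chosen image of any other digit. The sketch below computes that quantity directly on a tiny made-up example. `pairwise_auc` is a hypothetical helper written for illustration, not a scikit-learn function; it compares every positive with every negative, so it is only practical for small arrays.

```python
import numpy as np


def pairwise_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive minus every negative
    correct = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(correct / (len(pos) * len(neg)))


# Made-up one-vs-rest labels and scores for a single class.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(y_true, scores))  # 0.75
```

Here three of the four positive/negative pairs are ordered correctly, giving 0.75. Values close to 1.0, like those printed above, mean nearly every such pair is ordered correctly.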

