In this tutorial, we will train a Generative Adversarial Network (GAN) on the MNIST dataset, a large collection of 28x28 pixel images of handwritten digits. We will try to train a network that produces convincing new images of digits.
This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.
To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment.
In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')
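Once the cell finishes, it's worth sanity-checking the install before moving on. A minimal check (assuming the script put version 2.3.0 on the path):

import deepchem
print(deepchem.__version__)  # should print 2.3.0 if the install succeeded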
To begin, let's import all the libraries we'll need and load the dataset (which comes bundled with TensorFlow).
In [2]:
import deepchem as dc
import tensorflow as tf
from deepchem.models.optimizers import ExponentialDecay
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, Reshape
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
%matplotlib inline
# Load MNIST (TF 1.x bundled loader) and reshape each image to 28x28x1
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
images = mnist.train.images.reshape((-1, 28, 28, 1))
# Wrap the images in a DeepChem dataset
dataset = dc.data.NumpyDataset(images)
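As a quick sanity check, we can inspect the wrapped data. The TF 1.x loader holds out 5,000 of the 60,000 MNIST training images for validation, so we expect 55,000 samples:

# Each sample is a 28x28 single-channel image
print(dataset.X.shape)  # (55000, 28, 28, 1)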
Let's view some of the images to get an idea of what they look like.
In [3]:
def plot_digits(im):
    plot.figure(figsize=(3, 3))
    grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)
    for i, g in enumerate(grid):
        ax = plot.subplot(g)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(im[i,:,:,0], cmap='gray')

plot_digits(images)
Now we can create our GAN. It consists of two parts:

1. The generator takes random noise as its input and produces output that will hopefully come to resemble the training data.
2. The discriminator takes a set of samples as input (possibly training data, possibly created by the generator) and tries to determine which are which.

Because the class below subclasses dc.models.WGAN, this model is a Wasserstein GAN, a variant whose training tends to be more stable than the original GAN formulation.
In [4]:
class DigitGAN(dc.models.WGAN):

    def get_noise_input_shape(self):
        # The generator's input: a vector of 10 random values per sample
        return (10,)

    def get_data_input_shapes(self):
        # The training data: 28x28 single-channel images
        return [(28, 28, 1)]

    def create_generator(self):
        # Map the noise to a 7x7 feature map, then upsample twice to 28x28
        return tf.keras.Sequential([
            Dense(7*7*8, activation=tf.nn.relu),
            Reshape((7, 7, 8)),
            Conv2DTranspose(filters=16, kernel_size=5, strides=2, activation=tf.nn.relu, padding='same'),
            Conv2DTranspose(filters=1, kernel_size=5, strides=2, activation=tf.sigmoid, padding='same')
        ])

    def create_discriminator(self):
        # Downsample the image twice, then flatten so the final Dense layer
        # produces a single score per sample
        return tf.keras.Sequential([
            Conv2D(filters=32, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),
            Conv2D(filters=64, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),
            Flatten(),
            Dense(1, activation=tf.math.softplus)
        ])
gan = DigitGAN(learning_rate=ExponentialDecay(0.001, 0.9, 5000))
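The ExponentialDecay schedule starts the learning rate at 0.001 and cuts it by a factor of 0.9 after every 5,000 training steps. A rough illustration of the resulting rates (assuming DeepChem's default staircase-style decay):

# Approximate learning rate at a few points in training
for step in [0, 5000, 10000, 15000]:
    print(step, 0.001 * 0.9 ** (step // 5000))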
Now let's train it. The generator and discriminator are trained together: the generator tries to get better at fooling the discriminator, while the discriminator tries to get better at distinguishing real data from generated data (which in turn gives the generator a better training signal to learn from). Setting generator_steps=0.2 below trains the generator once for every five discriminator steps, which tends to give better results for WGANs.
In [5]:
def iterbatches(epochs):
    # Yield batches of real images, keyed by the GAN's data input tensor
    for i in range(epochs):
        for batch in dataset.iterbatches(batch_size=gan.batch_size):
            yield {gan.data_inputs[0]: batch[0]}

gan.fit_gan(iterbatches(100), generator_steps=0.2, checkpoint_interval=5000)
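Training takes a while, and Colab sessions can time out. If you want the checkpoints written every 5,000 steps to survive, one option is to pass an explicit model_dir when constructing the GAN and reload from it later. A sketch, assuming KerasModel's model_dir and restore() behavior (the 'digit_gan' directory name is just an example):

# Construct the model with a persistent checkpoint directory...
gan = DigitGAN(learning_rate=ExponentialDecay(0.001, 0.9, 5000), model_dir='digit_gan')
# ...and in a later session, rebuild it the same way and reload the latest
# checkpoint instead of retraining from scratch.
gan.restore()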
Let's generate some data and see how the results look.
In [6]:
plot_digits(gan.predict_gan_generator(batch_size=16))
Not too bad. Many of the generated images look plausibly like handwritten digits. A larger model trained for a longer time can do much better, of course.
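By default, predict_gan_generator draws its own random noise, so each call produces different digits. If you want reproducible samples, you can supply the noise yourself. A sketch, assuming predict_gan_generator's noise_input argument and Gaussian noise matching get_noise_input_shape():

import numpy as np

# Fixing the seed makes the same 16 noise vectors (and so the same digits)
# come back every run
np.random.seed(123)
noise = np.random.normal(size=(16, 10))
plot_digits(gan.predict_gan_generator(noise_input=noise))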
Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:
1. Star DeepChem on GitHub. This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.
2. Join the DeepChem Gitter. The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!