Tutorial Part 16: Conditional Generative Adversarial Network

Note: This example implements a GAN from scratch. The same model could be implemented much more easily with the dc.models.GAN class. See the MNIST GAN notebook for an example of using that class. It can still be useful to know how to implement a GAN from scratch for advanced situations that are beyond the scope of what the standard GAN class supports.

A Generative Adversarial Network (GAN) is a type of generative model. It consists of two parts called the "generator" and the "discriminator". The generator takes random values as input and transforms them into an output that (hopefully) resembles the training data. The discriminator takes a set of samples as input and tries to distinguish the real training samples from the ones created by the generator. The two are trained together: the discriminator tries to get better and better at telling real data from generated data, while the generator tries to get better and better at fooling the discriminator.

A Conditional GAN (CGAN) gives the generator and the discriminator additional inputs on which their outputs are conditioned. For example, the extra input might be a class label, in which case the GAN tries to learn how the data distribution varies between classes.
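
As a rough illustration of what "conditioning" means in practice, the sketch below builds a conditioned generator input with plain NumPy: the class label is one-hot encoded and placed next to the random values. The helper name one_hot and the sizes are assumptions chosen for illustration, not part of DeepChem.

```python
import numpy as np

def one_hot(labels, n_classes):
    """Return a (len(labels), n_classes) array with a 1 in each label's column."""
    encoded = np.zeros((len(labels), n_classes))
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded

rng = np.random.default_rng(0)
n_classes, noise_dim, batch = 4, 10, 3   # illustrative sizes

labels = np.array([2, 0, 3])             # one class label per sample
noise = rng.random((batch, noise_dim))   # random input to the generator
condition = one_hot(labels, n_classes)   # the conditioning input

# The conditioned generator input: noise and label side by side.
generator_input = np.concatenate([noise, condition], axis=1)
print(generator_input.shape)  # (3, 14)
```

The discriminator can be conditioned the same way, by placing the one-hot label next to each sample it is asked to judge.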


This tutorial and the rest in this sequence are designed to be done in Google Colab.


To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment.

In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')

TensorFlow 1.x selected.
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  3477  100  3477    0     0   7902      0 --:--:-- --:--:-- --:--:--  7884
add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
python version: 3.6.9
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
installing miniconda to /root/miniconda
installing deepchem
/usr/local/lib/python3.6/dist-packages/sklearn/externals/joblib/__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
  warnings.warn(msg, category=FutureWarning)
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

deepchem-2.3.0 installation finished!
CPU times: user 2.44 s, sys: 517 ms, total: 2.96 s
Wall time: 1min 57s

For this example, we will create a data distribution consisting of a set of ellipses in 2D, each with a random position, shape, and orientation. Each class corresponds to a different ellipse. Let's randomly generate the ellipses.

In [0]:
import deepchem as dc
import numpy as np
import tensorflow as tf

n_classes = 4
class_centers = np.random.uniform(-4, 4, (n_classes, 2))
class_transforms = []
for i in range(n_classes):
    xscale = np.random.uniform(0.5, 2)
    yscale = np.random.uniform(0.5, 2)
    angle = np.random.uniform(0, np.pi)
    m = [[xscale*np.cos(angle), -yscale*np.sin(angle)],
         [xscale*np.sin(angle), yscale*np.cos(angle)]]
    class_transforms.append(m)  # store this class's transform
class_transforms = np.array(class_transforms)
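
The matrix m built in the loop can be read as an axis-aligned scaling followed by a rotation, which is what turns a unit disc into a tilted ellipse. The following sketch checks that reading numerically; the specific values of xscale, yscale and angle are arbitrary examples.

```python
import numpy as np

xscale, yscale, angle = 1.5, 0.7, np.pi / 3   # arbitrary example values

# The matrix as written in the tutorial code.
m = np.array([[xscale*np.cos(angle), -yscale*np.sin(angle)],
              [xscale*np.sin(angle),  yscale*np.cos(angle)]])

# The same matrix as a rotation applied after an axis-aligned scaling.
rotation = np.array([[np.cos(angle), -np.sin(angle)],
                     [np.sin(angle),  np.cos(angle)]])
scaling = np.diag([xscale, yscale])
same = np.allclose(m, rotation @ scaling)

# The unit vector along x maps to a point at distance xscale from the origin,
# so xscale is one semi-axis of the resulting ellipse.
semi_axis_length = np.linalg.norm(m @ np.array([1.0, 0.0]))
print(same, semi_axis_length)
```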

This function generates random data from the distribution. For each point it chooses a random class, then a random position in that class's ellipse.

In [0]:
def generate_data(n_points):
    classes = np.random.randint(n_classes, size=n_points)
    r = np.random.random(n_points)
    angle = 2*np.pi*np.random.random(n_points)
    points = (r*np.array([np.cos(angle), np.sin(angle)])).T
    points = np.einsum('ijk,ik->ij', class_transforms[classes], points)
    points += class_centers[classes]
    return classes, points
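
The einsum call in generate_data() applies a different 2x2 matrix to every point in one step. A small sketch, using random stand-in arrays rather than the tutorial's ellipses, shows that it matches an explicit per-point loop:

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_points = 4, 5

# Stand-ins for the tutorial's arrays: one 2x2 matrix per class and one
# 2D point per sample.
class_transforms = rng.normal(size=(n_classes, 2, 2))
classes = rng.integers(n_classes, size=n_points)
points = rng.normal(size=(n_points, 2))

# Batched form used in generate_data().
batched = np.einsum('ijk,ik->ij', class_transforms[classes], points)

# Equivalent explicit loop: transform each point with its own class matrix.
looped = np.array([class_transforms[c] @ p for c, p in zip(classes, points)])

print(np.allclose(batched, looped))  # True
```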

Let's plot a bunch of random points drawn from this distribution to see what it looks like. Points are colored based on their class label.

In [4]:
%matplotlib inline
import matplotlib.pyplot as plot
classes, points = generate_data(1000)
plot.scatter(x=points[:,0], y=points[:,1], c=classes)

<matplotlib.collections.PathCollection at 0x7f89693c4dd8>

Now let's create the model for our CGAN.

In [0]:
import deepchem.models.tensorgraph.layers as layers
model = dc.models.TensorGraph(learning_rate=1e-4, use_queue=False)

# Inputs to the model

random_in = layers.Feature(shape=(None, 10)) # Random input to the generator
generator_classes = layers.Feature(shape=(None, n_classes)) # The classes of the generated samples
real_data_points = layers.Feature(shape=(None, 2)) # The training samples
real_data_classes = layers.Feature(shape=(None, n_classes)) # The classes of the training samples
is_real = layers.Weights(shape=(None, 1)) # Flags to distinguish real from generated samples

# The generator

gen_in = layers.Concat([random_in, generator_classes])
gen_dense1 = layers.Dense(30, in_layers=gen_in, activation_fn=tf.nn.relu)
gen_dense2 = layers.Dense(30, in_layers=gen_dense1, activation_fn=tf.nn.relu)
generator_points = layers.Dense(2, in_layers=gen_dense2)

# The discriminator

all_points = layers.Concat([generator_points, real_data_points], axis=0)
all_classes = layers.Concat([generator_classes, real_data_classes], axis=0)
discrim_in = layers.Concat([all_points, all_classes])
discrim_dense1 = layers.Dense(30, in_layers=discrim_in, activation_fn=tf.nn.relu)
discrim_dense2 = layers.Dense(30, in_layers=discrim_dense1, activation_fn=tf.nn.relu)
discrim_prob = layers.Dense(1, in_layers=discrim_dense2, activation_fn=tf.sigmoid)
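
The discriminator sees generated and real samples in a single batch, so the order of the concatenation matters. The NumPy sketch below mirrors the three Concat layers with stand-in arrays (the random values and the batch size of 3 are illustrative assumptions) and shows the matching layout of the is_real flags.

```python
import numpy as np

batch_size, n_classes = 3, 4
rng = np.random.default_rng(0)

generated_points = rng.normal(size=(batch_size, 2))   # stand-in for generator output
real_points = rng.normal(size=(batch_size, 2))        # stand-in for training samples
class_flags = np.eye(n_classes)[rng.integers(n_classes, size=batch_size)]

# Stack generated samples first, then real ones (axis 0), as in the model above.
all_points = np.concatenate([generated_points, real_points], axis=0)
all_classes = np.concatenate([class_flags, class_flags], axis=0)

# Attach each sample's one-hot class to its coordinates (axis 1).
discrim_in = np.concatenate([all_points, all_classes], axis=1)

# Matching flags: 0 for the generated half, 1 for the real half.
is_real = np.concatenate([np.zeros((batch_size, 1)), np.ones((batch_size, 1))])

print(discrim_in.shape, is_real.ravel())
```

Each row of discrim_in holds two coordinates followed by n_classes class flags, and row i of is_real says whether that row came from the training data.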

We'll use different loss functions for training the generator and discriminator. The discriminator outputs its predictions in the form of a probability that each sample is a real sample (that is, that it came from the training set rather than the generator). Its loss consists of two terms. The first term tries to maximize the output probability for real data, and the second term tries to minimize the output probability for generated samples. The loss function for the generator is just a single term: it tries to maximize the discriminator's output probability for generated samples.

For each one, we create a "submodel" specifying a set of layers that will be optimized based on a loss function.

In [0]:
# Discriminator

discrim_real_data_loss = -layers.Log(discrim_prob+1e-10) * is_real
discrim_gen_data_loss = -layers.Log(1-discrim_prob+1e-10) * (1-is_real)
discrim_loss = layers.ReduceMean(discrim_real_data_loss + discrim_gen_data_loss)
discrim_submodel = model.create_submodel(layers=[discrim_dense1, discrim_dense2, discrim_prob], loss=discrim_loss)

# Generator

gen_loss = -layers.ReduceMean(layers.Log(discrim_prob+1e-10) * (1-is_real))
gen_submodel = model.create_submodel(layers=[gen_dense1, gen_dense2, generator_points], loss=gen_loss)
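
To make the two losses concrete, here is a NumPy sketch of the same formulas evaluated on hand-picked probabilities. The function names and the example probabilities are assumptions for illustration; the arithmetic follows the layer expressions above.

```python
import numpy as np

def discriminator_loss(prob, is_real, eps=1e-10):
    """Mean of -log(p) on real samples and -log(1 - p) on generated ones."""
    real_term = -np.log(prob + eps) * is_real
    gen_term = -np.log(1 - prob + eps) * (1 - is_real)
    return np.mean(real_term + gen_term)

def generator_loss(prob, is_real, eps=1e-10):
    """Mean of -log(p) over the whole batch, counting only generated samples."""
    return -np.mean(np.log(prob + eps) * (1 - is_real))

# Two generated samples followed by two real ones.
is_real = np.array([[0.0], [0.0], [1.0], [1.0]])
confident = np.array([[0.1], [0.1], [0.9], [0.9]])   # discriminator is right
fooled = np.array([[0.9], [0.9], [0.9], [0.9]])      # generator fools it

d_conf = discriminator_loss(confident, is_real)
d_fool = discriminator_loss(fooled, is_real)
g_conf = generator_loss(confident, is_real)
g_fool = generator_loss(fooled, is_real)
print(d_conf, d_fool, g_conf, g_fool)
```

The discriminator loss is lower when it classifies correctly, and the generator loss is lower when the discriminator is fooled, which is the opposition the two submodels are trained on.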

Now to fit the model. Here are some important points to notice about the code.

  • We use fit_generator() to train only a single batch at a time, and we alternate between the discriminator and the generator. That way, both parts of the model improve together.
  • We only train the generator half as often as the discriminator. On this particular model, that gives much better results. You will often need to adjust (# of discriminator steps)/(# of generator steps) to get good results on a given problem.
  • We disable checkpointing by specifying checkpoint_interval=0. Since each call to fit_generator() includes only a single batch, it would otherwise save a checkpoint to disk after every batch, which would be very slow. If this were a real project and not just an example, we would want to occasionally call model.save_checkpoint() to write checkpoints at a reasonable interval.
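
The step ratio mentioned in the second point can be made an explicit parameter. The sketch below is a hypothetical helper (training_schedule and discrim_steps_per_gen_step are names invented here, not DeepChem API) that only decides on which steps the generator would be trained; it assumes the discriminator is trained on every step.

```python
def training_schedule(n_steps, discrim_steps_per_gen_step=2):
    """Yield (step, train_generator) pairs.

    The discriminator is assumed to train on every step; the generator
    trains only on every `discrim_steps_per_gen_step`-th step.
    """
    for step in range(n_steps):
        yield step, step % discrim_steps_per_gen_step == 0

schedule = list(training_schedule(6))
n_gen_updates = sum(train_gen for _, train_gen in schedule)
print(n_gen_updates)  # 3
```

With the default ratio of 2, this matches a test of the form step%2 == 0 in a training loop; changing the ratio is then a one-argument experiment.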

In [7]:
batch_size = model.batch_size
discrim_error = []
gen_error = []
for step in range(20000):
    classes, points = generate_data(batch_size)
    class_flags = dc.metrics.to_one_hot(classes, n_classes)
    feed_dict={random_in: np.random.random((batch_size, 10)),
               generator_classes: class_flags,
               real_data_points: points,
               real_data_classes: class_flags,
               is_real: np.concatenate([np.zeros((batch_size,1)), np.ones((batch_size,1))])}
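    # Train the discriminator on every step.
    discrim_error.append(model.fit_generator([feed_dict],
                                             submodel=discrim_submodel,
                                             checkpoint_interval=0))
    # Train the generator only on every second step.
    if step%2 == 0:
        gen_error.append(model.fit_generator([feed_dict],
                                             submodel=gen_submodel,
                                             checkpoint_interval=0))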
    if step%1000 == 999:
        print(step, np.mean(discrim_error), np.mean(gen_error))
        discrim_error = []
        gen_error = []

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:714: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/layers.py:1634: The name tf.log is deprecated. Please use tf.math.log instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:727: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/optimizers.py:76: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:1012: The name tf.get_collection is deprecated. Please use tf.compat.v1.get_collection instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:1012: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:738: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.

WARNING:tensorflow:From /root/miniconda/lib/python3.6/site-packages/deepchem/models/tensorgraph/tensor_graph.py:748: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.

999 0.5156213084459305 0.37282696121931075
1999 0.39635649234056475 0.6554632024765015
2999 0.4816185410916805 0.6439448493719101
3999 0.6881231372356414 0.41854076969623566
4999 0.6954806981682777 0.36900784534215925
5999 0.6934329395890236 0.34684676861763003
6999 0.6871857723593712 0.3469327309727669
7999 0.6882104944586754 0.35844097477197645
8999 0.6879851130247117 0.34883454167842864
9999 0.6891423400640487 0.3533225782513619
10999 0.6890938600897789 0.352202350795269
11999 0.6911078352332115 0.3480358254909515
12999 0.6913300577402115 0.34874600952863694
13999 0.6922475056052207 0.34867041957378386
14999 0.691593163728714 0.34903139680624007
15999 0.6911602554917335 0.35044702333211897
16999 0.6909645751714707 0.35226673740148545
17999 0.6911768457889557 0.3513581330180168
18999 0.6894893513917923 0.3482932530641556
19999 0.6915659754276275 0.35546432530879973
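
One way to read these numbers is to ask what the two losses would be if the discriminator output exactly 0.5 for every sample, that is, if it could not tell generated points from real ones. The sketch below evaluates the loss formulas under that assumption; the batch size of 100 is an arbitrary choice and does not affect the result.

```python
import numpy as np

batch_size = 100
# Generated samples first, real samples second, as in the training loop.
is_real = np.concatenate([np.zeros((batch_size, 1)), np.ones((batch_size, 1))])
prob = np.full((2 * batch_size, 1), 0.5)   # discriminator cannot tell them apart

discrim_loss = float(np.mean(-np.log(prob) * is_real
                             - np.log(1 - prob) * (1 - is_real)))
gen_loss = float(-np.mean(np.log(prob) * (1 - is_real)))

print(round(discrim_loss, 4), round(gen_loss, 4))  # 0.6931 0.3466
```

These are ln 2 and (ln 2)/2; the generator value is halved because its loss is averaged over the whole batch while only the generated half contributes. The logged errors from roughly step 4000 onward sit close to both values, which is consistent with (though not proof of) the discriminator being unable to separate the two kinds of samples.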

Have the trained model generate some data, and see how well it matches the training distribution we plotted before.

In [8]:
classes, points = generate_data(1000)
feed_dict = {random_in: np.random.random((1000, 10)),
             generator_classes: dc.metrics.to_one_hot(classes, n_classes)}
gen_points = model.predict_on_generator([feed_dict], outputs=generator_points)
plot.scatter(x=gen_points[:,0], y=gen_points[:,1], c=classes)

<matplotlib.collections.PathCollection at 0x7f890a003d68>
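
Beyond comparing the two plots by eye, a simple numeric check is to compare the mean generated point of each class with that class's true center. The helper below is a hypothetical sketch (class_centroid_errors is not a DeepChem function) and is demonstrated on stand-in data; in the notebook one would pass gen_points, classes and class_centers instead. It assumes every class has at least one point.

```python
import numpy as np

def class_centroid_errors(points, classes, class_centers):
    """Distance between each class's mean point and its true center."""
    errors = []
    for c, center in enumerate(class_centers):
        centroid = points[classes == c].mean(axis=0)
        errors.append(np.linalg.norm(centroid - center))
    return np.array(errors)

# Stand-in data: two classes scattered tightly around known centers.
rng = np.random.default_rng(0)
class_centers = np.array([[0.0, 0.0], [3.0, 1.0]])
classes = np.repeat([0, 1], 500)
points = class_centers[classes] + 0.1 * rng.normal(size=(1000, 2))

errors = class_centroid_errors(points, classes, class_centers)
print(errors)
```

Small errors only show that the generated clouds are centered correctly; they say nothing about shape or orientation, so the scatter plot remains the more informative check.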

Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

Star DeepChem on GitHub

This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

Join the DeepChem Gitter

The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!