Tutorial Part 16: Conditional Generative Adversarial Network

Note: This example implements a GAN from scratch. The same model could be implemented much more easily with the dc.models.GAN class. See the MNIST GAN notebook for an example of using that class. It can still be useful to know how to implement a GAN from scratch for advanced situations that are beyond the scope of what the standard GAN class supports.

A Generative Adversarial Network (GAN) is a type of generative model. It consists of two parts called the "generator" and the "discriminator". The generator takes random values as input and transforms them into an output that (hopefully) resembles the training data. The discriminator takes a set of samples as input and tries to distinguish the real training samples from the ones created by the generator. The two are trained together: the discriminator tries to get better and better at telling real from fake data, while the generator tries to get better and better at fooling the discriminator.
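Formally, the two networks play the standard GAN minimax game, where D(x) is the discriminator's estimated probability that a sample x is real and G(z) is the generator's output for random noise z:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_\text{data}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$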

A Conditional GAN (CGAN) allows additional inputs to the generator and discriminator that their output is conditioned on. For example, this might be a class label, and the GAN tries to learn how the data distribution varies between classes.
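Concretely, in this tutorial the conditioning is done by concatenating a one-hot class vector onto the inputs of both the generator and the discriminator, so each network always knows which class a sample belongs to. Here is a minimal NumPy illustration of that input construction (the names below are purely illustrative and are not part of the model defined later):

import numpy as np

n_classes = 4
noise_dim = 10

noise = np.random.random((1, noise_dim))      # random input for one generated sample
one_hot = np.eye(n_classes)[[2]]              # one-hot label for class 2, shape (1, 4)

# The generator's actual input is the noise and the class label, side by side.
generator_input = np.concatenate([noise, one_hot], axis=1)   # shape (1, 14)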

Colab

This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.

Setup

To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment.


In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')


TensorFlow 1.x selected.
add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
python version: 3.6.9
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing deepchem
done

deepchem-2.3.0 installation finished!
CPU times: user 2.58 s, sys: 541 ms, total: 3.12 s
Wall time: 4min 12s

For this example, we will create a data distribution consisting of a set of ellipses in 2D, each with a random position, shape, and orientation. Each class corresponds to a different ellipse. Let's randomly generate the ellipses.


In [0]:
import deepchem as dc
import numpy as np
import tensorflow as tf

n_classes = 4
class_centers = np.random.uniform(-4, 4, (n_classes, 2))
class_transforms = []
for i in range(n_classes):
    # Each class gets random x/y scales and a random rotation angle,
    # which together define the shape and orientation of its ellipse.
    xscale = np.random.uniform(0.5, 2)
    yscale = np.random.uniform(0.5, 2)
    angle = np.random.uniform(0, np.pi)
    # Combined matrix that scales the unit disk along the axes, then rotates it.
    m = [[xscale*np.cos(angle), -yscale*np.sin(angle)],
         [xscale*np.sin(angle), yscale*np.cos(angle)]]
    class_transforms.append(m)
class_transforms = np.array(class_transforms)

This function generates random data from the distribution. For each point it chooses a random class, then a random position inside that class's ellipse.


In [0]:
def generate_data(n_points):
    # Pick a random class for each point.
    classes = np.random.randint(n_classes, size=n_points)
    # Sample a point in the unit disk (random radius and angle)...
    r = np.random.random(n_points)
    angle = 2*np.pi*np.random.random(n_points)
    points = (r*np.array([np.cos(angle), np.sin(angle)])).T
    # ...then map it into the chosen class's ellipse and shift to the class center.
    points = np.einsum('ijk,ik->ij', class_transforms[classes], points)
    points += class_centers[classes]
    return classes, points

Let's plot a bunch of random points drawn from this distribution to see what it looks like. Points are colored based on their class label.


In [4]:
%matplotlib inline
import matplotlib.pyplot as plot
classes, points = generate_data(1000)
plot.scatter(x=points[:,0], y=points[:,1], c=classes)


Out[4]:
<matplotlib.collections.PathCollection at 0x7ff8cf6e5e10>

Now let's create the model for our CGAN.


In [0]:
import deepchem.models.tensorgraph.layers as layers
model = dc.models.TensorGraph(learning_rate=1e-4, use_queue=False)

# Inputs to the model

random_in = layers.Feature(shape=(None, 10)) # Random input to the generator
generator_classes = layers.Feature(shape=(None, n_classes)) # The classes of the generated samples
real_data_points = layers.Feature(shape=(None, 2)) # The training samples
real_data_classes = layers.Feature(shape=(None, n_classes)) # The classes of the training samples
is_real = layers.Weights(shape=(None, 1)) # Flags to distinguish real from generated samples

# The generator

gen_in = layers.Concat([random_in, generator_classes])
gen_dense1 = layers.Dense(30, in_layers=gen_in, activation_fn=tf.nn.relu)
gen_dense2 = layers.Dense(30, in_layers=gen_dense1, activation_fn=tf.nn.relu)
generator_points = layers.Dense(2, in_layers=gen_dense2)
model.add_output(generator_points)

# The discriminator

all_points = layers.Concat([generator_points, real_data_points], axis=0)
all_classes = layers.Concat([generator_classes, real_data_classes], axis=0)
discrim_in = layers.Concat([all_points, all_classes])
discrim_dense1 = layers.Dense(30, in_layers=discrim_in, activation_fn=tf.nn.relu)
discrim_dense2 = layers.Dense(30, in_layers=discrim_dense1, activation_fn=tf.nn.relu)
discrim_prob = layers.Dense(1, in_layers=discrim_dense2, activation_fn=tf.sigmoid)

We'll use different loss functions for training the generator and discriminator. The discriminator outputs its predictions in the form of a probability that each sample is a real sample (that is, that it came from the training set rather than the generator). Its loss consists of two terms. The first term tries to maximize the output probability for real data, and the second term tries to minimize the output probability for generated samples. The loss function for the generator is just a single term: it tries to maximize the discriminator's output probability for generated samples.
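Written out (and ignoring the small constant added for numerical stability in the code below), these are the standard non-saturating GAN losses, conditioned on the class label c:

$$L_\text{discriminator} = -\mathbb{E}_\text{real}\big[\log D(x, c)\big] - \mathbb{E}_\text{generated}\big[\log\big(1 - D(G(z, c), c)\big)\big]$$

$$L_\text{generator} = -\mathbb{E}_\text{generated}\big[\log D(G(z, c), c)\big]$$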

For each one, we create a "submodel" specifying a set of layers that will be optimized based on a loss function.


In [0]:
# Discriminator

discrim_real_data_loss = -layers.Log(discrim_prob+1e-10) * is_real
discrim_gen_data_loss = -layers.Log(1-discrim_prob+1e-10) * (1-is_real)
discrim_loss = layers.ReduceMean(discrim_real_data_loss + discrim_gen_data_loss)
discrim_submodel = model.create_submodel(layers=[discrim_dense1, discrim_dense2, discrim_prob], loss=discrim_loss)

# Generator

gen_loss = -layers.ReduceMean(layers.Log(discrim_prob+1e-10) * (1-is_real))
gen_submodel = model.create_submodel(layers=[gen_dense1, gen_dense2, generator_points], loss=gen_loss)

Now to fit the model. Here are some important points to notice about the code.

  • We use fit_generator() to train only a single batch at a time, and we alternate between the discriminator and the generator. That way, both parts of the model improve together.
  • We only train the generator half as often as the discriminator. On this particular model, that gives much better results. You will often need to adjust (# of discriminator steps)/(# of generator steps) to get good results on a given problem.
  • We disable checkpointing by specifying checkpoint_interval=0. Since each call to fit_generator() includes only a single batch, it would otherwise save a checkpoint to disk after every batch, which would be very slow. If this were a real project and not just an example, we would want to occasionally call model.save_checkpoint() to write checkpoints at a reasonable interval (see the short sketch after this list).
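For example, the training loop in the next cell could be lightly modified to write a checkpoint every few thousand steps. This is only a sketch, and the interval of 5000 steps is an arbitrary choice:

for step in range(20000):
    # ... build feed_dict and call fit_generator() exactly as in the cell below ...
    if step % 5000 == 4999:
        # Manually write a checkpoint at a reasonable interval.
        model.save_checkpoint()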

In [7]:
batch_size = model.batch_size
discrim_error = []
gen_error = []
for step in range(20000):
    classes, points = generate_data(batch_size)
    class_flags = dc.metrics.to_one_hot(classes, n_classes)
    # The discriminator sees generated points first and real points second
    # (matching the Concat order in the model), so is_real is zeros followed by ones.
    feed_dict = {random_in: np.random.random((batch_size, 10)),
                 generator_classes: class_flags,
                 real_data_points: points,
                 real_data_classes: class_flags,
                 is_real: np.concatenate([np.zeros((batch_size,1)), np.ones((batch_size,1))])}
    discrim_error.append(model.fit_generator([feed_dict],
                                             submodel=discrim_submodel,
                                             checkpoint_interval=0))
    if step%2 == 0:
        gen_error.append(model.fit_generator([feed_dict],
                                             submodel=gen_submodel,
                                             checkpoint_interval=0))
    if step%1000 == 999:
        print(step, np.mean(discrim_error), np.mean(gen_error))
        discrim_error = []
        gen_error = []


999 0.40659638777375223 0.46297103410959245
1999 0.3384278678447008 0.8597527861595153
2999 0.28886374438554047 0.9156105797290802
3999 0.42119870174676177 0.840492086827755
4999 0.6085479544401169 0.45372276908159254
5999 0.7160498830676079 0.35652754533290865
6999 0.6727188802361488 0.36548557299375534
7999 0.6745303119421006 0.36139536756277085
8999 0.6672571448087692 0.3725837562680244
9999 0.6735138981938362 0.36844821578264236
10999 0.6764357801675797 0.36131750684976577
11999 0.6835235329866409 0.3585679198503494
12999 0.6849101437330246 0.3555546105504036
13999 0.6862603163719178 0.35470631182193757
14999 0.6857899969816208 0.3557598451972008
15999 0.6868707528114318 0.35640183770656586
16999 0.6868409720659256 0.3557077826857567
17999 0.6868168808817864 0.3548824065327644
18999 0.6882137333750725 0.35393582719564437
19999 0.6891591399312019 0.3511381688117981

Notice that by the end of training the discriminator loss has settled near ln 2 ≈ 0.69, which is what we expect once the discriminator can no longer tell real samples from generated ones and outputs a probability of about 0.5 for everything. Let's have the trained model generate some data and see how well it matches the training distribution we plotted before.


In [8]:
classes, points = generate_data(1000)
feed_dict = {random_in: np.random.random((1000, 10)),
             generator_classes: dc.metrics.to_one_hot(classes, n_classes)}
gen_points = model.predict_on_generator([feed_dict])
plot.scatter(x=gen_points[:,0], y=gen_points[:,1], c=classes)


Out[8]:
<matplotlib.collections.PathCollection at 0x7ff8c0208b38>
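As a quick follow-up, you can also condition the generator on a single class to look at one ellipse at a time. This reuses the same predict_on_generator() call as above, just with every sample given the same (arbitrarily chosen) class label:

# Generate 1000 samples, all conditioned on class 0.
single_class = np.zeros(1000, dtype=int)
feed_dict = {random_in: np.random.random((1000, 10)),
             generator_classes: dc.metrics.to_one_hot(single_class, n_classes)}
single_class_points = model.predict_on_generator([feed_dict])
plot.scatter(x=single_class_points[:,0], y=single_class_points[:,1])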

Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

Star DeepChem on GitHub

This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

Join the DeepChem Gitter

The DeepChem Gitter hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!