Test my idea of train 2 pairs of Gans at the same time


In [1]:
%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [3]:
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

In [4]:
def generator(z, out_dim, n_units=128, reuse=False,  alpha=0.01):
    ''' Build the generator network.
    
        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('generator', reuse=reuse): # finish this
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(z, out_dim, activation=None)
        out = tf.tanh(logits)
        
        return out

In [5]:
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.
    
        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('discriminator', reuse=reuse): # finish this
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(h1 * alpha, h1)
        
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        
        return out, logits

Hyperparameter


In [6]:
# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing 
smooth = 0.1

Build Networks


In [24]:
tf.reset_default_graph()

input_real, input_z = model_inputs(input_size, z_size)

# Generative network 1
with tf.variable_scope('1'):
    g_model_1 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# Generative network 2
with tf.variable_scope('2'):
    g_model_2 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# An empty network for buffer parameter, called generative network 3
with tf.variable_scope('3'):
    g_model_3 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# Discriminator network 1
with tf.variable_scope('1'):
    d_model_real_1, d_logits_real_1 = discriminator(input_real, d_hidden_size, reuse=False, alpha=alpha)
    d_model_fake_1, d_logits_fake_1 = discriminator(g_model_1, d_hidden_size, reuse=True, alpha=alpha)
    
# Discriminator network 2
with tf.variable_scope('2'):
    d_model_real_2, d_logits_real_2 = discriminator(input_real, d_hidden_size, reuse=False, alpha=alpha)
    d_model_fake_2, d_logits_fake_2 = discriminator(g_model_2, d_hidden_size, reuse=True, alpha=alpha)

In [25]:
# Calculate losses 1
d_loss_real_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real_1, labels=tf.ones_like(d_logits_real_1) * (1 - smooth)))

d_loss_fake_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_1, labels=tf.zeros_like(d_logits_fake_1)))

d_loss_1 = d_loss_real_1 + d_loss_fake_1

g_loss_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_1, labels=tf.ones_like(d_logits_fake_1)))

# Calculate losses 2
d_loss_real_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real_2, labels=tf.ones_like(d_logits_real_2) * (1 - smooth)))

d_loss_fake_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_2, labels=tf.zeros_like(d_logits_fake_2)))

d_loss_2 = d_loss_real_2 + d_loss_fake_2

g_loss_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_2, labels=tf.ones_like(d_logits_fake_2)))

In [26]:
# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars_1 = [var for var in t_vars if var.name.startswith('1/generator')]
d_vars_1 = [var for var in t_vars if var.name.startswith('1/discriminator')]

g_vars_2 = [var for var in t_vars if var.name.startswith('2/generator')]
d_vars_2 = [var for var in t_vars if var.name.startswith('2/discriminator')]

# Get the variable of the buffering network
g_vars_3 = [var for var in t_vars if var.name.startswith('3/generator')]

d_train_opt_1 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_1, var_list=d_vars_1)
g_train_opt_1 = tf.train.AdamOptimizer(learning_rate).minimize(g_loss_1, var_list=g_vars_1)

d_train_opt_2 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_2, var_list=d_vars_2)
g_train_opt_2 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_2, var_list=g_vars_2)

In [27]:
# Swap operation
swap_ops = []
# First, assign the values of weight of generator 1 to 3
for g_var_1, g_var_3 in zip(sorted(g_vars_1, key=lambda v: v.name),
                            sorted(g_vars_3, key=lambda v: v.name)):
    swap_ops.append(g_var_3.assign(g_var_1))
    
# Then assign the values of weight of generator 2 to 1
for g_var_2, g_var_1 in zip(sorted(g_vars_2, key=lambda v: v.name),
                            sorted(g_vars_1, key=lambda v: v.name)):
    swap_ops.append(g_var_1.assign(g_var_2))
    
# Last, assign the values of weight of generator 3 to 2
for g_var_3, g_var_2 in zip(sorted(g_vars_3, key=lambda v: v.name),
                            sorted(g_vars_2, key=lambda v: v.name)):
    swap_ops.append(g_var_2.assign(g_var_3))

Training


In [28]:
batch_size = 100
epochs = 100
swap_every = 1
samples = []
losses_1 = []
saver = tf.train.Saver(var_list = g_vars_1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        
        #swap the weights
        if e != 0 and e % swap_every == 0:
            sess.run(swap_ops)
        
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run([d_train_opt_1, d_train_opt_2], feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run([g_train_opt_1, g_train_opt_2], feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d_1 = sess.run(d_loss_1, {input_z: batch_z, input_real: batch_images})
        train_loss_g_1 = g_loss_1.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d_1),
              "Generator Loss: {:.4f}".format(train_loss_g_1))    
        # Save losses to view after training
        losses_1.append((train_loss_d_1, train_loss_g_1))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        with tf.variable_scope('1'):
            gen_samples_1 = sess.run(
                           generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                           feed_dict={input_z: sample_z})
            samples.append(gen_samples_1)
            saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)


Epoch 1/100... Discriminator Loss: 0.3757... Generator Loss: 3.9855
Epoch 2/100... Discriminator Loss: 0.4653... Generator Loss: 2.9071
Epoch 3/100... Discriminator Loss: 0.3761... Generator Loss: 3.8936
Epoch 4/100... Discriminator Loss: 0.4102... Generator Loss: 3.2064
Epoch 5/100... Discriminator Loss: 0.6615... Generator Loss: 3.4003
Epoch 6/100... Discriminator Loss: 0.4584... Generator Loss: 3.2772
Epoch 7/100... Discriminator Loss: 0.4844... Generator Loss: 2.8434
Epoch 8/100... Discriminator Loss: 0.5668... Generator Loss: 2.7142
Epoch 9/100... Discriminator Loss: 0.8430... Generator Loss: 1.9150
Epoch 10/100... Discriminator Loss: 0.7352... Generator Loss: 3.1392
Epoch 11/100... Discriminator Loss: 0.8525... Generator Loss: 2.9962
Epoch 12/100... Discriminator Loss: 0.7526... Generator Loss: 2.7717
Epoch 13/100... Discriminator Loss: 0.6751... Generator Loss: 2.5483
Epoch 14/100... Discriminator Loss: 0.8982... Generator Loss: 2.0455
Epoch 15/100... Discriminator Loss: 1.1173... Generator Loss: 1.8504
Epoch 16/100... Discriminator Loss: 0.8552... Generator Loss: 2.6219
Epoch 17/100... Discriminator Loss: 0.8394... Generator Loss: 2.6981
Epoch 18/100... Discriminator Loss: 0.8045... Generator Loss: 2.0370
Epoch 19/100... Discriminator Loss: 0.9427... Generator Loss: 3.3688
Epoch 20/100... Discriminator Loss: 0.7980... Generator Loss: 2.1409
Epoch 21/100... Discriminator Loss: 0.9339... Generator Loss: 1.7893
Epoch 22/100... Discriminator Loss: 0.8970... Generator Loss: 2.6436
Epoch 23/100... Discriminator Loss: 0.7305... Generator Loss: 2.5124
Epoch 24/100... Discriminator Loss: 0.8755... Generator Loss: 2.8540
Epoch 25/100... Discriminator Loss: 0.7016... Generator Loss: 3.2231
Epoch 26/100... Discriminator Loss: 0.8453... Generator Loss: 3.1169
Epoch 27/100... Discriminator Loss: 0.7414... Generator Loss: 3.1597
Epoch 28/100... Discriminator Loss: 0.7325... Generator Loss: 2.2304
Epoch 29/100... Discriminator Loss: 0.6461... Generator Loss: 2.8740
Epoch 30/100... Discriminator Loss: 0.7139... Generator Loss: 2.5431
Epoch 31/100... Discriminator Loss: 0.8063... Generator Loss: 1.8342
Epoch 32/100... Discriminator Loss: 0.7634... Generator Loss: 2.3265
Epoch 33/100... Discriminator Loss: 0.5854... Generator Loss: 3.5718
Epoch 34/100... Discriminator Loss: 0.6595... Generator Loss: 2.6720
Epoch 35/100... Discriminator Loss: 0.7218... Generator Loss: 2.3076
Epoch 36/100... Discriminator Loss: 0.6637... Generator Loss: 2.4381
Epoch 37/100... Discriminator Loss: 0.5936... Generator Loss: 3.5894
Epoch 38/100... Discriminator Loss: 0.6122... Generator Loss: 2.7846
Epoch 39/100... Discriminator Loss: 0.6728... Generator Loss: 2.6127
Epoch 40/100... Discriminator Loss: 0.6174... Generator Loss: 4.2505
Epoch 41/100... Discriminator Loss: 0.6050... Generator Loss: 3.3698
Epoch 42/100... Discriminator Loss: 0.5566... Generator Loss: 3.2257
Epoch 43/100... Discriminator Loss: 0.6747... Generator Loss: 2.5823
Epoch 44/100... Discriminator Loss: 0.5057... Generator Loss: 3.2983
Epoch 45/100... Discriminator Loss: 0.5753... Generator Loss: 3.8629
Epoch 46/100... Discriminator Loss: 0.4892... Generator Loss: 3.6091
Epoch 47/100... Discriminator Loss: 0.5402... Generator Loss: 3.4772
Epoch 48/100... Discriminator Loss: 0.5713... Generator Loss: 3.3652
Epoch 49/100... Discriminator Loss: 0.5634... Generator Loss: 3.4757
Epoch 50/100... Discriminator Loss: 0.5270... Generator Loss: 3.2793
Epoch 51/100... Discriminator Loss: 0.5633... Generator Loss: 3.1725
Epoch 52/100... Discriminator Loss: 0.5970... Generator Loss: 3.2997
Epoch 53/100... Discriminator Loss: 0.5580... Generator Loss: 3.2613
Epoch 54/100... Discriminator Loss: 0.5150... Generator Loss: 3.7703
Epoch 55/100... Discriminator Loss: 0.5456... Generator Loss: 3.3902
Epoch 56/100... Discriminator Loss: 0.5335... Generator Loss: 3.8166
Epoch 57/100... Discriminator Loss: 0.4989... Generator Loss: 3.5395
Epoch 58/100... Discriminator Loss: 0.5642... Generator Loss: 2.8278
Epoch 59/100... Discriminator Loss: 0.5123... Generator Loss: 3.8302
Epoch 60/100... Discriminator Loss: 0.4958... Generator Loss: 3.2975
Epoch 61/100... Discriminator Loss: 0.5325... Generator Loss: 3.0051
Epoch 62/100... Discriminator Loss: 0.5604... Generator Loss: 3.2356
Epoch 63/100... Discriminator Loss: 0.5065... Generator Loss: 3.1153
Epoch 64/100... Discriminator Loss: 0.4688... Generator Loss: 3.6201
Epoch 65/100... Discriminator Loss: 0.5367... Generator Loss: 3.8893
Epoch 66/100... Discriminator Loss: 0.5460... Generator Loss: 3.4578
Epoch 67/100... Discriminator Loss: 0.4953... Generator Loss: 3.8375
Epoch 68/100... Discriminator Loss: 0.5230... Generator Loss: 3.0408
Epoch 69/100... Discriminator Loss: 0.4747... Generator Loss: 3.6777
Epoch 70/100... Discriminator Loss: 0.4652... Generator Loss: 3.9949
Epoch 71/100... Discriminator Loss: 0.4999... Generator Loss: 3.3991
Epoch 72/100... Discriminator Loss: 0.5275... Generator Loss: 3.4254
Epoch 73/100... Discriminator Loss: 0.4793... Generator Loss: 4.1214
Epoch 74/100... Discriminator Loss: 0.4949... Generator Loss: 3.7742
Epoch 75/100... Discriminator Loss: 0.4892... Generator Loss: 3.6412
Epoch 76/100... Discriminator Loss: 0.4631... Generator Loss: 3.3791
Epoch 77/100... Discriminator Loss: 0.4732... Generator Loss: 4.2137
Epoch 78/100... Discriminator Loss: 0.4640... Generator Loss: 3.6214
Epoch 79/100... Discriminator Loss: 0.5032... Generator Loss: 3.9282
Epoch 80/100... Discriminator Loss: 0.5304... Generator Loss: 3.9594
Epoch 81/100... Discriminator Loss: 0.4396... Generator Loss: 4.0431
Epoch 82/100... Discriminator Loss: 0.4697... Generator Loss: 3.7697
Epoch 83/100... Discriminator Loss: 0.4560... Generator Loss: 3.6247
Epoch 84/100... Discriminator Loss: 0.5060... Generator Loss: 3.7933
Epoch 85/100... Discriminator Loss: 0.5003... Generator Loss: 3.5883
Epoch 86/100... Discriminator Loss: 0.4907... Generator Loss: 3.7253
Epoch 87/100... Discriminator Loss: 0.5040... Generator Loss: 3.7299
Epoch 88/100... Discriminator Loss: 0.4863... Generator Loss: 3.6790
Epoch 89/100... Discriminator Loss: 0.4404... Generator Loss: 4.1080
Epoch 90/100... Discriminator Loss: 0.4992... Generator Loss: 3.9743
Epoch 91/100... Discriminator Loss: 0.4631... Generator Loss: 4.2942
Epoch 92/100... Discriminator Loss: 0.5506... Generator Loss: 3.3433
Epoch 93/100... Discriminator Loss: 0.4493... Generator Loss: 4.4787
Epoch 94/100... Discriminator Loss: 0.4905... Generator Loss: 3.7438
Epoch 95/100... Discriminator Loss: 0.4757... Generator Loss: 4.3926
Epoch 96/100... Discriminator Loss: 0.5695... Generator Loss: 3.3693
Epoch 97/100... Discriminator Loss: 0.4859... Generator Loss: 3.8980
Epoch 98/100... Discriminator Loss: 0.4822... Generator Loss: 3.4105
Epoch 99/100... Discriminator Loss: 0.4658... Generator Loss: 3.8351
Epoch 100/100... Discriminator Loss: 0.5335... Generator Loss: 3.2338

Training Loss


In [29]:
%matplotlib inline

import matplotlib.pyplot as plt

In [31]:
fig, ax = plt.subplots()
losses = np.array(losses_1)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()


Out[31]:
<matplotlib.legend.Legend at 0x7fb3ee76b9b0>

Generate sample from training


In [32]:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

In [33]:
# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

In [34]:
_ = view_samples(-1, samples)



In [35]:
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)


Sampling from generator


In [37]:
saver = tf.train.Saver(var_list=g_vars_1)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    with tf.variable_scope('1'):
        gen_samples = sess.run(
                       generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                       feed_dict={input_z: sample_z})
view_samples(0, [gen_samples])


Out[37]:
(<matplotlib.figure.Figure at 0x7fb3ea9379e8>,
 array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7fb3fb952518>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3ea6e7898>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f05dfac8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0549438>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f051d5f8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0482748>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0454e80>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f03c3320>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0478400>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f037cdd8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f02e3f60>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f02be6d8>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0227860>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f01f7ac8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f0160c50>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7fb3f013b0f0>]], dtype=object))