Test my idea of train 2 pairs of Gans at the same time


In [1]:
%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [3]:
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

In [4]:
def generator(z, out_dim, n_units=128, reuse=False,  alpha=0.01):
    ''' Build the generator network.
    
        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('generator', reuse=reuse): # finish this
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(z, out_dim, activation=None)
        out = tf.tanh(logits)
        
        return out

In [5]:
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.
    
        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('discriminator', reuse=reuse): # finish this
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(h1 * alpha, h1)
        
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        
        return out, logits

Hyperparameter


In [6]:
# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing 
smooth = 0.1

Build Networks


In [7]:
tf.reset_default_graph()

input_real, input_z = model_inputs(input_size, z_size)

# Generative network 1
with tf.variable_scope('1'):
    g_model_1 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# Generative network 2
with tf.variable_scope('2'):
    g_model_2 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# An empty network for buffer parameter, called generative network 3
with tf.variable_scope('3'):
    g_model_3 = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    
# Discriminator network 1
with tf.variable_scope('1'):
    d_model_real_1, d_logits_real_1 = discriminator(input_real, d_hidden_size, reuse=False, alpha=alpha)
    d_model_fake_1, d_logits_fake_1 = discriminator(g_model_1, d_hidden_size, reuse=True, alpha=alpha)
    
# Discriminator network 2
with tf.variable_scope('2'):
    d_model_real_2, d_logits_real_2 = discriminator(input_real, d_hidden_size, reuse=False, alpha=alpha)
    d_model_fake_2, d_logits_fake_2 = discriminator(g_model_2, d_hidden_size, reuse=True, alpha=alpha)

In [8]:
# Calculate losses 1
d_loss_real_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real_1, labels=tf.ones_like(d_logits_real_1) * (1 - smooth)))

d_loss_fake_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_1, labels=tf.zeros_like(d_logits_fake_1)))

d_loss_1 = d_loss_real_1 + d_loss_fake_1

g_loss_1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_1, labels=tf.ones_like(d_logits_fake_1)))

# Calculate losses 2
d_loss_real_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real_2, labels=tf.ones_like(d_logits_real_2) * (1 - smooth)))

d_loss_fake_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_2, labels=tf.zeros_like(d_logits_fake_2)))

d_loss_2 = d_loss_real_2 + d_loss_fake_2

g_loss_2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake_2, labels=tf.ones_like(d_logits_fake_2)))

In [9]:
# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars_1 = [var for var in t_vars if var.name.startswith('1/generator')]
d_vars_1 = [var for var in t_vars if var.name.startswith('1/discriminator')]

g_vars_2 = [var for var in t_vars if var.name.startswith('2/generator')]
d_vars_2 = [var for var in t_vars if var.name.startswith('2/discriminator')]

# Get the variable of the buffering network
g_vars_3 = [var for var in t_vars if var.name.startswith('3/generator')]

d_train_opt_1 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_1, var_list=d_vars_1)
g_train_opt_1 = tf.train.AdamOptimizer(learning_rate).minimize(g_loss_1, var_list=g_vars_1)

d_train_opt_2 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_2, var_list=d_vars_2)
g_train_opt_2 = tf.train.AdamOptimizer(learning_rate).minimize(d_loss_2, var_list=g_vars_2)

In [10]:
# Swap operation
swap_ops = []
# First, assign the values of weight of generator 1 to 3
for g_var_1, g_var_3 in zip(sorted(g_vars_1, key=lambda v: v.name),
                            sorted(g_vars_3, key=lambda v: v.name)):
    swap_ops.append(g_var_3.assign(g_var_1))
    
# Then assign the values of weight of generator 2 to 1
for g_var_2, g_var_1 in zip(sorted(g_vars_2, key=lambda v: v.name),
                            sorted(g_vars_1, key=lambda v: v.name)):
    swap_ops.append(g_var_1.assign(g_var_2))
    
# Last, assign the values of weight of generator 3 to 2
for g_var_3, g_var_2 in zip(sorted(g_vars_3, key=lambda v: v.name),
                            sorted(g_vars_2, key=lambda v: v.name)):
    swap_ops.append(g_var_2.assign(g_var_3))

Training


In [11]:
batch_size = 100
epochs = 100
swap_every = 2
samples = []
losses_1 = []
saver = tf.train.Saver(var_list = g_vars_1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        
        #swap the weights
        if e != 0 and e % swap_every == 0:
            sess.run(swap_ops)
        
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run([d_train_opt_1, d_train_opt_2], feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run([g_train_opt_1, g_train_opt_2], feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d_1 = sess.run(d_loss_1, {input_z: batch_z, input_real: batch_images})
        train_loss_g_1 = g_loss_1.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d_1),
              "Generator Loss: {:.4f}".format(train_loss_g_1))    
        # Save losses to view after training
        losses_1.append((train_loss_d_1, train_loss_g_1))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        with tf.variable_scope('1'):
            gen_samples_1 = sess.run(
                           generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                           feed_dict={input_z: sample_z})
            samples.append(gen_samples_1)
            saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)


Epoch 1/100... Discriminator Loss: 0.4009... Generator Loss: 4.0598
Epoch 2/100... Discriminator Loss: 0.4104... Generator Loss: 3.6800
Epoch 3/100... Discriminator Loss: 0.3384... Generator Loss: 5.2342
Epoch 4/100... Discriminator Loss: 0.3656... Generator Loss: 4.0374
Epoch 5/100... Discriminator Loss: 0.4690... Generator Loss: 2.8743
Epoch 6/100... Discriminator Loss: 0.5573... Generator Loss: 2.8412
Epoch 7/100... Discriminator Loss: 0.4467... Generator Loss: 3.4529
Epoch 8/100... Discriminator Loss: 0.8072... Generator Loss: 2.9690
Epoch 9/100... Discriminator Loss: 0.8128... Generator Loss: 2.2297
Epoch 10/100... Discriminator Loss: 0.6784... Generator Loss: 2.9664
Epoch 11/100... Discriminator Loss: 0.6334... Generator Loss: 2.2707
Epoch 12/100... Discriminator Loss: 0.8685... Generator Loss: 2.2955
Epoch 13/100... Discriminator Loss: 0.7108... Generator Loss: 2.7598
Epoch 14/100... Discriminator Loss: 0.9771... Generator Loss: 2.1738
Epoch 15/100... Discriminator Loss: 0.7974... Generator Loss: 2.0513
Epoch 16/100... Discriminator Loss: 1.0217... Generator Loss: 1.7984
Epoch 17/100... Discriminator Loss: 0.6423... Generator Loss: 3.6046
Epoch 18/100... Discriminator Loss: 0.8868... Generator Loss: 1.7227
Epoch 19/100... Discriminator Loss: 0.8293... Generator Loss: 2.6401
Epoch 20/100... Discriminator Loss: 0.9523... Generator Loss: 2.1076
Epoch 21/100... Discriminator Loss: 0.7687... Generator Loss: 2.7457
Epoch 22/100... Discriminator Loss: 0.7048... Generator Loss: 2.5421
Epoch 23/100... Discriminator Loss: 0.8096... Generator Loss: 2.1933
Epoch 24/100... Discriminator Loss: 0.7871... Generator Loss: 2.5559
Epoch 25/100... Discriminator Loss: 0.8016... Generator Loss: 2.1006
Epoch 26/100... Discriminator Loss: 0.7939... Generator Loss: 2.8046
Epoch 27/100... Discriminator Loss: 0.6936... Generator Loss: 3.0539
Epoch 28/100... Discriminator Loss: 0.8164... Generator Loss: 2.0265
Epoch 29/100... Discriminator Loss: 0.7107... Generator Loss: 2.6957
Epoch 30/100... Discriminator Loss: 0.6796... Generator Loss: 2.5540
Epoch 31/100... Discriminator Loss: 0.6140... Generator Loss: 2.7957
Epoch 32/100... Discriminator Loss: 0.6185... Generator Loss: 2.7838
Epoch 33/100... Discriminator Loss: 0.6359... Generator Loss: 3.3414
Epoch 34/100... Discriminator Loss: 0.7562... Generator Loss: 2.3183
Epoch 35/100... Discriminator Loss: 0.6354... Generator Loss: 2.9954
Epoch 36/100... Discriminator Loss: 0.7338... Generator Loss: 2.4577
Epoch 37/100... Discriminator Loss: 0.6839... Generator Loss: 2.1292
Epoch 38/100... Discriminator Loss: 0.6481... Generator Loss: 2.8357
Epoch 39/100... Discriminator Loss: 0.6838... Generator Loss: 2.9790
Epoch 40/100... Discriminator Loss: 0.6385... Generator Loss: 2.7815
Epoch 41/100... Discriminator Loss: 0.7269... Generator Loss: 2.0335
Epoch 42/100... Discriminator Loss: 0.6699... Generator Loss: 2.2821
Epoch 43/100... Discriminator Loss: 0.7470... Generator Loss: 2.8083
Epoch 44/100... Discriminator Loss: 0.6033... Generator Loss: 2.7103
Epoch 45/100... Discriminator Loss: 0.6527... Generator Loss: 3.2372
Epoch 46/100... Discriminator Loss: 0.7616... Generator Loss: 2.5144
Epoch 47/100... Discriminator Loss: 0.7748... Generator Loss: 2.1777
Epoch 48/100... Discriminator Loss: 0.6339... Generator Loss: 2.6534
Epoch 49/100... Discriminator Loss: 0.6870... Generator Loss: 2.6280
Epoch 50/100... Discriminator Loss: 0.7089... Generator Loss: 2.3301
Epoch 51/100... Discriminator Loss: 0.7386... Generator Loss: 2.7281
Epoch 52/100... Discriminator Loss: 0.6190... Generator Loss: 2.6396
Epoch 53/100... Discriminator Loss: 0.6073... Generator Loss: 3.1413
Epoch 54/100... Discriminator Loss: 0.5892... Generator Loss: 2.7362
Epoch 55/100... Discriminator Loss: 0.6796... Generator Loss: 2.9846
Epoch 56/100... Discriminator Loss: 0.5863... Generator Loss: 3.0761
Epoch 57/100... Discriminator Loss: 0.6380... Generator Loss: 3.0471
Epoch 58/100... Discriminator Loss: 0.6244... Generator Loss: 3.6599
Epoch 59/100... Discriminator Loss: 0.6480... Generator Loss: 3.4138
Epoch 60/100... Discriminator Loss: 0.6406... Generator Loss: 2.8782
Epoch 61/100... Discriminator Loss: 0.5686... Generator Loss: 3.1121
Epoch 62/100... Discriminator Loss: 0.5520... Generator Loss: 3.4075
Epoch 63/100... Discriminator Loss: 0.6158... Generator Loss: 2.4138
Epoch 64/100... Discriminator Loss: 0.5775... Generator Loss: 2.9708
Epoch 65/100... Discriminator Loss: 0.6775... Generator Loss: 3.3417
Epoch 66/100... Discriminator Loss: 0.5791... Generator Loss: 3.0862
Epoch 67/100... Discriminator Loss: 0.5802... Generator Loss: 2.7029
Epoch 68/100... Discriminator Loss: 0.5226... Generator Loss: 3.0028
Epoch 69/100... Discriminator Loss: 0.6605... Generator Loss: 2.6302
Epoch 70/100... Discriminator Loss: 0.5728... Generator Loss: 2.9240
Epoch 71/100... Discriminator Loss: 0.5535... Generator Loss: 3.4677
Epoch 72/100... Discriminator Loss: 0.5025... Generator Loss: 3.2679
Epoch 73/100... Discriminator Loss: 0.5797... Generator Loss: 3.5077
Epoch 74/100... Discriminator Loss: 0.5779... Generator Loss: 3.2572
Epoch 75/100... Discriminator Loss: 0.5540... Generator Loss: 3.1079
Epoch 76/100... Discriminator Loss: 0.5190... Generator Loss: 3.5352
Epoch 77/100... Discriminator Loss: 0.5921... Generator Loss: 2.7035
Epoch 78/100... Discriminator Loss: 0.5327... Generator Loss: 3.1557
Epoch 79/100... Discriminator Loss: 0.6458... Generator Loss: 2.5996
Epoch 80/100... Discriminator Loss: 0.4864... Generator Loss: 3.5296
Epoch 81/100... Discriminator Loss: 0.5046... Generator Loss: 3.3739
Epoch 82/100... Discriminator Loss: 0.5656... Generator Loss: 3.0207
Epoch 83/100... Discriminator Loss: 0.5397... Generator Loss: 3.4638
Epoch 84/100... Discriminator Loss: 0.5454... Generator Loss: 2.9330
Epoch 85/100... Discriminator Loss: 0.5537... Generator Loss: 3.5797
Epoch 86/100... Discriminator Loss: 0.6276... Generator Loss: 2.3383
Epoch 87/100... Discriminator Loss: 0.5196... Generator Loss: 3.5945
Epoch 88/100... Discriminator Loss: 0.5159... Generator Loss: 3.6218
Epoch 89/100... Discriminator Loss: 0.5419... Generator Loss: 3.2057
Epoch 90/100... Discriminator Loss: 0.5268... Generator Loss: 3.5761
Epoch 91/100... Discriminator Loss: 0.4994... Generator Loss: 3.4216
Epoch 92/100... Discriminator Loss: 0.5050... Generator Loss: 3.6190
Epoch 93/100... Discriminator Loss: 0.5731... Generator Loss: 2.6393
Epoch 94/100... Discriminator Loss: 0.4834... Generator Loss: 3.6152
Epoch 95/100... Discriminator Loss: 0.5514... Generator Loss: 2.8033
Epoch 96/100... Discriminator Loss: 0.4863... Generator Loss: 3.8432
Epoch 97/100... Discriminator Loss: 0.5298... Generator Loss: 3.3608
Epoch 98/100... Discriminator Loss: 0.5550... Generator Loss: 3.4308
Epoch 99/100... Discriminator Loss: 0.5486... Generator Loss: 3.1751
Epoch 100/100... Discriminator Loss: 0.4251... Generator Loss: 4.4085

Training Loss


In [12]:
%matplotlib inline

import matplotlib.pyplot as plt

In [13]:
fig, ax = plt.subplots()
losses = np.array(losses_1)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()


Out[13]:
<matplotlib.legend.Legend at 0x7f2444cfef28>

Generate sample from training


In [14]:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

In [15]:
# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

In [16]:
_ = view_samples(-1, samples)



In [17]:
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)


Sampling from generator


In [18]:
saver = tf.train.Saver(var_list=g_vars_1)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    with tf.variable_scope('1'):
        gen_samples = sess.run(
                       generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                       feed_dict={input_z: sample_z})
view_samples(0, [gen_samples])


Out[18]:
(<matplotlib.figure.Figure at 0x7f244b6acbe0>,
 array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7f244a6733c8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f244a1d48d0>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f244a1a5438>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f245500b438>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7f2455028eb8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454f45cf8>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454f2ef60>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454e88400>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7f2454e70668>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454dc4be0>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454dada90>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454d084e0>],
        [<matplotlib.axes._subplots.AxesSubplot object at 0x7f2454cef470>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454c42da0>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454c33048>,
         <matplotlib.axes._subplots.AxesSubplot object at 0x7f2454c61d30>]], dtype=object))