In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from ./data/MNIST_data/; one_hot=True encodes each label as a
# one-hot vector instead of an integer class id.
mnist = input_data.read_data_sets(
    train_dir = "./data/MNIST_data/",
    one_hot = True,
)
In [3]:
# Plot one training image as a sanity check. Images are handled as flat
# 784-value vectors throughout this notebook, so reshape to 28x28 for display.
plt.imshow(mnist.train.images[0].reshape(28, 28),
           cmap = 'Greys')
In [4]:
# Creating a generator network
# The input is a batch of latent noise vectors (100 values each, see the `z`
# placeholder); the output is a batch of 784-value images.
def generator(z, reuse = None, alpha = 0.01):
    """Map latent noise `z` to flattened images through two hidden layers.

    Args:
        z: float tensor of latent vectors; presumably shape (batch, 100) as fed
            by the `z` placeholder.
        reuse: forwarded to `tf.variable_scope('gen', ...)`. Leave as None on the
            first call (which creates the variables); pass True to share them.
        alpha: slope applied to negative activations by the leaky ReLU in both
            hidden layers (default 0.01, the value previously hard-coded).

    Returns:
        Tensor with 784 units per example, squashed into (-1, 1) by tanh.
    """
    with tf.variable_scope('gen', reuse = reuse):
        hidden1 = tf.layers.dense(inputs = z, units = 128)
        # Leaky ReLU written with tf.maximum: for 0 < alpha < 1,
        # max(alpha * x, x) is x when x >= 0 and alpha * x when x < 0.
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden2 = tf.layers.dense(inputs = hidden1, units = 128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)
        # tanh keeps generated pixels in the same [-1, 1] range the real
        # images are rescaled to before training.
        output = tf.layers.dense(hidden2,
                                 units = 784,
                                 activation = tf.nn.tanh)
        return output
In [5]:
# Creating a discriminator network
def discriminator(X, reuse = None, alpha = 0.01):
    """Score flattened images: high for real, low for generated.

    Args:
        X: float tensor of flattened images; presumably shape (batch, 784).
        reuse: forwarded to `tf.variable_scope('dis', ...)`. Leave as None on the
            first call (which creates the variables); pass True on later calls so
            real and generated images are scored by the same weights.
        alpha: slope applied to negative activations by the leaky ReLU in both
            hidden layers (default 0.01, the value previously hard-coded).

    Returns:
        (output, logits): `logits` is the raw single-unit score per example and
        `output` is its sigmoid.
    """
    with tf.variable_scope('dis', reuse = reuse):
        hidden1 = tf.layers.dense(inputs = X, units = 128)
        # Leaky ReLU written with tf.maximum: for 0 < alpha < 1,
        # max(alpha * x, x) is x when x >= 0 and alpha * x when x < 0.
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden2 = tf.layers.dense(inputs = hidden1, units = 128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)
        logits = tf.layers.dense(hidden2, units = 1)
        output = tf.sigmoid(logits)
        return output, logits
In [6]:
# Graph inputs:
#   real_images - batches of flattened 28x28 images (784 values per row)
#   z           - batches of 100-value latent noise for the generator
real_images = tf.placeholder(dtype = tf.float32, shape = (None, 784))
z = tf.placeholder(dtype = tf.float32, shape = (None, 100))
In [7]:
# Generator output for the noise placeholder; this first call creates the 'gen' variables.
G = generator(z, reuse = None)
In [8]:
# Discriminator applied to real images; this first call creates the 'dis' variables.
D_output_real, D_logits_real = discriminator(X = real_images, reuse = None)
In [9]:
# Discriminator applied to generated images, sharing the 'dis' weights (reuse=True).
D_output_fake, D_logits_fake = discriminator(X = G, reuse = True)
In [10]:
def loss_func(logits_in, labels_in):
    """Return the mean sigmoid cross-entropy of `logits_in` against `labels_in`."""
    per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = labels_in,
                                                                logits = logits_in)
    return tf.reduce_mean(per_example_loss)
In [11]:
# Real images get target 0.9 instead of 1.0 (label smoothing on the real side only).
D_real_loss = loss_func(D_logits_real, 0.9 * tf.ones_like(D_logits_real))
In [12]:
# Generated images get target 0 (fake). The targets are shaped like
# D_logits_fake, the logits they are compared against, rather than
# D_logits_real; the two only coincide when real and fake batches match in size.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
In [13]:
# Total discriminator loss: error on real images plus error on generated images.
D_loss = tf.add(D_real_loss, D_fake_loss)
In [14]:
# Generator loss: low when the discriminator labels generated images as real (target 1),
# i.e. it measures how well the generator fools the discriminator.
G_loss = loss_func(logits_in = D_logits_fake,
                   labels_in = tf.ones_like(D_logits_fake))
In [15]:
learning_rate = 1e-3  # step size for both Adam optimizers
In [16]:
# Split the trainable variables by the variable scope that created them, so each
# optimizer updates only its own network.
tvars = tf.trainable_variables()
# Match the scope prefix ('dis/...', 'gen/...') instead of a bare substring, so a
# variable whose name merely contains 'dis' or 'gen' elsewhere is not misassigned.
d_vars = [var for var in tvars if var.name.startswith('dis/')]
g_vars = [var for var in tvars if var.name.startswith('gen/')]
print([v.name for v in d_vars])
print([v.name for v in g_vars])
In [17]:
# One Adam optimizer per network; var_list keeps each update confined to its own weights.
D_trainer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(
    loss = D_loss, var_list = d_vars)
G_trainer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(
    loss = G_loss, var_list = g_vars)
In [18]:
batch_size = 100  # images per optimization step
epochs = 200      # full passes over the training set
# Created after both optimizers, so it also covers their internal variables.
init = tf.global_variables_initializer()
# Checkpoint only the generator's variables; sampling later needs nothing else.
saver = tf.train.Saver(var_list = g_vars)
In [19]:
# One generated sample is appended here at the end of every training epoch.
samples = list()
In [20]:
with tf.Session() as sess:
    # A new session starts with every variable uninitialized.
    sess.run(init)
    # Mini-batches per epoch. // is floor division, so a trailing partial batch
    # is skipped. Computed once because it does not change between epochs.
    num_batches = mnist.train.num_examples // batch_size
    # Recall an epoch is an entire run through the training data
    for e in range(epochs):
        for _ in range(num_batches):
            # Grab a batch of images (batch[0]); the labels are not used.
            batch = mnist.train.next_batch(batch_size)
            # Flatten to 784 values and rescale (presumably from [0, 1]) to
            # [-1, 1], the range of the generator's tanh output.
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images * 2 - 1
            # Z (random latent noise data for the generator), uniform in [-1, 1)
            batch_z = np.random.uniform(-1, 1, size = (batch_size, 100))
            # Train the discriminator on the real images and on the fake
            # images the generator produces from the noise z.
            sess.run(D_trainer, feed_dict = {real_images: batch_images,
                                             z: batch_z})
            # Train the generator
            sess.run(G_trainer, feed_dict = {z: batch_z})
        print("Currently on Epoch {} of {} total...".format(e + 1, epochs))
        # Sample from the generator for viewing afterwards. G is already the
        # generator applied to z; calling generator(z, reuse=True) here would
        # add another copy of the generator subgraph to the graph every epoch.
        sample_z = np.random.uniform(-1, 1, size = (1, 100))
        gen_sample = sess.run(G, feed_dict = {z: sample_z})
        samples.append(gen_sample)
    # Make sure the checkpoint directory exists before saving into it.
    tf.gfile.MakeDirs('./models')
    # NOTE(review): the file name says 500 epochs but `epochs` is 200; the
    # restoring cell reads this same path, so rename both together if needed.
    saver.save(sess, './models/500_epoch_model.ckpt')
In [21]:
# Reload the trained generator weights and draw five fresh samples.
saver = tf.train.Saver(var_list = g_vars)
new_samples = []
with tf.Session() as sess:
    # Restoring g_vars is enough: G is built only from variables in the 'gen' scope.
    saver.restore(sess, './models/500_epoch_model.ckpt')
    for _ in range(5):
        sample_z = np.random.uniform(-1, 1,
                                     size = (1, 100))
        # Run the existing generator output G rather than building a new
        # generator subgraph on every iteration.
        gen_sample = sess.run(G, feed_dict = {z: sample_z})
        new_samples.append(gen_sample)
In [22]:
# Show the last sample drawn from the restored generator as a 28x28 image.
plt.imshow(X = new_samples[-1].reshape((28, 28)), cmap = 'Greys')