In [1]:
import tensorflow as tf
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import sys
import glob
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
In [2]:
from tensorflow.python.platform import flags
import argparse
tf.app.flags.FLAGS = flags._FlagValues()
tf.app.flags._global_parser = argparse.ArgumentParser()
tf.app.flags.DEFINE_integer(flag_name="epochs", docstring="number of training epoches", default_value=10)
tf.app.flags.DEFINE_integer(flag_name="crop_height", docstring="image cropping height", default_value=500)
tf.app.flags.DEFINE_integer(flag_name="crop_width", docstring="image cropping width", default_value=500)
tf.app.flags.DEFINE_integer(flag_name="target_height", docstring="image resize height", default_value=64)
tf.app.flags.DEFINE_integer(flag_name="target_width", docstring="image resize width", default_value=64)
tf.app.flags.DEFINE_integer(flag_name="batch_size", docstring="image batchsize", default_value=128)
tf.app.flags.DEFINE_float(flag_name="learning_rate", docstring="learning rate", default_value=1e-3)
tf.app.flags.DEFINE_boolean(flag_name="is_training", docstring="whether the model is at training stage", default_value=True)
In [3]:
class DataProcess(object):
def __init__(self, src_dir):
self.src_dir = src_dir
self._process()
def _process(self):
# Following code is according to:
# https://stackoverflow.com/questions/33849617/how-do-i-convert-a-directory-of-jpeg-images-to-tfrecords-file-in-tensorflow
input_files = tf.train.string_input_producer(glob.glob(os.path.join(self.src_dir,"*")), num_epochs = tf.app.flags.FLAGS.epochs)
reader = tf.WholeFileReader()
key, value = reader.read(input_files)
self.labels = key
self.imgs = tf.random_crop(tf.image.decode_jpeg(value), (tf.app.flags.FLAGS.crop_height, tf.app.flags.FLAGS.crop_width, 3))
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * tf.app.flags.FLAGS.batch_size
self.label_batch, self.img_batch = tf.train.shuffle_batch([self.labels, self.imgs], \
batch_size=tf.app.flags.FLAGS.batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
# Resize The image to target size
self.img_batch = tf.image.resize_area(self.img_batch, (tf.app.flags.FLAGS.target_height, tf.app.flags.FLAGS.target_width))
def get_data(self):
return self.label_batch, self.img_batch
In [111]:
dp = DataProcess("../dataset/102flowers/jpg/")
dp._process()
In [116]:
with tf.Session() as s:
s.run(tf.global_variables_initializer())
s.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while True:
labels, imgs = s.run(dp.get_data(), feed_dict = {})
#print(np.min(imgs))
plt.imshow(imgs[0,:,:,:]/256.0)
plt.show()
break
except tf.errors.OutOfRangeError as e:
print("fetch data ended")
coord.request_stop()
coord.join(threads)
In [4]:
class Generator(object):
"""
Generator in GAN, used to generate "fake" images
Original paper: https://arxiv.org/pdf/1511.06434.pdf
"""
def __init__(self):
self._build_graph()
def _build_graph(self):
with tf.variable_scope("generator") as scope:
print("### Print Generator Intermediate Parameter")
self.prior = tf.placeholder(dtype=tf.float32, shape=(None, 100), name="prior_gen")
self.is_training = tf.placeholder(dtype=tf.bool, shape = (), name="training_flag")
prior_proj = tf.contrib.layers.fully_connected(inputs=self.prior, num_outputs=4*4*1024,
activation_fn=None, scope="prior_projection")
prior_proj = tf.contrib.layers.batch_norm(inputs=prior_proj, center=True, scale=True, activation_fn=tf.nn.relu,
is_training= self.is_training, scope="bn0")
conv0 = tf.reshape(prior_proj, (-1, 4, 4, 1024))
conv1 = tf.contrib.layers.convolution2d_transpose(inputs=conv0, num_outputs=512, activation_fn=None,
kernel_size=(5,5), stride=(2,2), padding="SAME",scope="deconv1")
conv1 = tf.contrib.layers.batch_norm(inputs=conv1, center=True, scale=True, activation_fn=tf.nn.relu,
is_training= self.is_training, scope="bn1")
print(conv1.shape)
conv2 = tf.contrib.layers.convolution2d_transpose(inputs=conv1, num_outputs=256, activation_fn=None,
kernel_size=(5,5), stride=(2,2), padding="SAME",scope="deconv2")
conv2 = tf.contrib.layers.batch_norm(inputs=conv2, center=True, scale=True, activation_fn=tf.nn.relu,
is_training= self.is_training, scope="bn2")
print(conv2.shape)
conv3 = tf.contrib.layers.convolution2d_transpose(inputs=conv2, num_outputs=128, activation_fn=None,
kernel_size=(5,5), stride=(2,2), padding="SAME",scope="deconv3")
conv3 = tf.contrib.layers.batch_norm(inputs=conv3, center=True, scale=True, activation_fn=tf.nn.relu,
is_training= self.is_training, scope="bn3")
print(conv3.shape)
conv4 = tf.contrib.layers.convolution2d_transpose(inputs=conv3, num_outputs=3, activation_fn=None,
kernel_size=(5,5), stride=(2,2), padding="SAME",scope="deconv4")
self.gen_img = tf.nn.tanh(conv4)
self.gen_img_out = tf.cast(x=self.gen_img*128.0 + 128.0, dtype=tf.int32)
print(conv4.shape)
print("### End Print Generator Intermediate Parameter")
In [118]:
# tf.reset_default_graph()
# g = Generator()
In [10]:
class Discriminator(object):
"""
Discriminator in GAN, used to distinguish "fake" and "real" images
"""
def __init__(self, img_gen):
self._build_graph(img_gen)
def _build_graph(self, image_gen):
self.real_img = tf.placeholder(tf.float32, (None, 64, 64, 3), name="real_image")
real_img = (self.real_img - 128.0)/128.0
self.is_training = tf.placeholder(dtype=tf.bool, shape = (), name="training_flag")
self.real_judge = self._discrim(real_img)
#print(self.real_judge)
self.fake_judge = self._discrim(image_gen.gen_img, reuse = True)
#print(self.fake_judge)
def _discrim(self, input_img, reuse = None):
"""
This function will be called twice,
one for real images, and one for fake images.
"""
with tf.variable_scope("discriminator") as scope:
if reuse: scope.reuse_variables()
print("### Print Discriminator Intermediate Parameter")
print(self.is_training)
conv1 = tf.contrib.layers.convolution2d(inputs=input_img, num_outputs=64, padding="SAME",
kernel_size=(3,3), stride=(1,1), activation_fn=None, scope = "conv1")
conv1 = tf.contrib.layers.batch_norm(inputs=conv1, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn1")
print(conv1.shape)
conv2 = tf.contrib.layers.convolution2d(inputs=conv1, num_outputs=128, padding="SAME",
kernel_size=(3,3), stride=(2,2), activation_fn=None, scope = "conv2")
conv2 = tf.contrib.layers.batch_norm(inputs=conv2, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn2")
print(conv2.shape)
###
conv3 = tf.contrib.layers.convolution2d(inputs=conv2, num_outputs=128, padding="SAME",
kernel_size=(3,3), stride=(1,1), activation_fn=None, scope = "conv3")
conv3 = tf.contrib.layers.batch_norm(inputs=conv3, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn3")
print(conv3.shape)
conv4 = tf.contrib.layers.convolution2d(inputs=conv3, num_outputs=256, padding="SAME",
kernel_size=(3,3), stride=(2,2), activation_fn=None, scope = "conv4")
conv4 = tf.contrib.layers.batch_norm(inputs=conv4, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn4")
print(conv4.shape)
###
conv5 = tf.contrib.layers.convolution2d(inputs=conv4, num_outputs=256, padding="SAME",
kernel_size=(3,3), stride=(1,1), activation_fn=None, scope = "conv5")
conv5 = tf.contrib.layers.batch_norm(inputs=conv5, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn5")
print(conv5.shape)
conv6 = tf.contrib.layers.convolution2d(inputs=conv5, num_outputs=1024, padding="SAME",
kernel_size=(3,3), stride=(2,2), activation_fn=None, scope = "conv6")
conv6 = tf.contrib.layers.batch_norm(inputs=conv6, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn6")
print(conv6.shape)
conv7 = tf.contrib.layers.convolution2d(inputs=conv6, num_outputs=1024, padding="SAME",
kernel_size=(3,3), stride=(1,1), activation_fn=None, scope = "conv7")
conv7 = tf.contrib.layers.batch_norm(inputs=conv7, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn7")
print(conv7.shape)
conv8 = tf.contrib.layers.convolution2d(inputs=conv7, num_outputs=2048, padding="SAME",
kernel_size=(3,3), stride=(2,2), activation_fn=None, scope = "conv8")
conv8 = tf.contrib.layers.batch_norm(inputs=conv8, center=True, scale=True, activation_fn=tf.nn.leaky_relu,
is_training= self.is_training, scope="bn8")
print(conv8.shape)
conv9 = tf.contrib.layers.convolution2d(inputs=conv8, num_outputs=102, padding="SAME",
kernel_size=(1,1), stride=(1,1), activation_fn=tf.nn.leaky_relu, scope="conv9")
conv9 = tf.contrib.layers.avg_pool2d(inputs=conv9, kernel_size=(4,4), stride=(4,4), padding="SAME", scope="avg_pool9")
print(conv9.shape)
print("### End Print Discriminator Intermediate Parameter")
### no need to perform sigmoid
return tf.contrib.layers.fully_connected(inputs=tf.reshape(conv9,(-1, 102)), num_outputs=1,
activation_fn=None, scope="output_projection")
In [120]:
# tf.reset_default_graph()
# g = Generator()
# d = Discriminator(g)
In [121]:
# tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
In [122]:
# tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
In [12]:
class GANModel(object):
def __init__(self, generator, discriminator, datasrc):
self.generator = generator
self.discriminator = discriminator
self.datasrc = datasrc
self.sess = None
self.saver = tf.train.Saver()
def train(self, model_path = None):
self.sess = tf.Session()
fake_result = self.discriminator.fake_judge
real_result = self.discriminator.real_judge
fake_rate = tf.reduce_mean(tf.cast(tf.nn.sigmoid(fake_result) > 0.5, tf.float32))
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_result , labels = tf.ones_like(fake_result)))
loss_d = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_result , labels = tf.ones_like(real_result)))\
+ tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_result , labels = tf.zeros_like(fake_result)))
optim_g = tf.train.AdamOptimizer(tf.app.flags.FLAGS.learning_rate).minimize(loss_g, var_list = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
optim_d = tf.train.AdamOptimizer(tf.app.flags.FLAGS.learning_rate).minimize(loss_d, var_list = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator'))
writer = tf.summary.FileWriter("./log")
summary_g = tf.summary.scalar(name="generator_loss", tensor=loss_g)
summary_d = tf.summary.scalar(name="discriminator_loss", tensor=loss_d)
summary_fake_rate = tf.summary.scalar(name="fake_rate", tensor=fake_rate)
if model_path:
self.saver.restore(self.sess, model_path)
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.local_variables_initializer())
cnt_total = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=self.sess)
try:
while True:
labels, imgs = self.sess.run(self.datasrc.get_data(), feed_dict = {})
# First train discriminator
# print(self.discriminator.is_training)
# print(self.generator.is_training)
loss_d_out, _, summary_d_out = self.sess.run([loss_d, optim_d, summary_d],
feed_dict = {
self.discriminator.real_img: imgs,
self.generator.prior: \
np.random.randn(tf.app.flags.FLAGS.batch_size, 100)*1e-2,
self.discriminator.is_training: True,
self.generator.is_training: True
})
# Then train generator
loss_g_out, _, summary_g_out = self.sess.run([loss_g, optim_g, summary_g],
feed_dict = {
self.generator.prior: \
np.random.randn(tf.app.flags.FLAGS.batch_size, 100)*1e-2,
self.discriminator.is_training: True,
self.generator.is_training: True,
})
# Then evaluate the fake ratio
fake_rate_out, summary_fake_rate_out = self.sess.run([fake_rate, summary_fake_rate],
feed_dict = {
self.generator.prior: \
np.random.randn(tf.app.flags.FLAGS.batch_size, 100)*1e-2,
self.discriminator.is_training: False,
self.generator.is_training: False,
})
cnt_total += 1
writer.add_summary(summary_d_out, cnt_total)
writer.add_summary(summary_g_out, cnt_total)
writer.add_summary(summary_fake_rate_out, cnt_total)
print("In batch %3d, Dicriminator Loss %.3f, Generator Loss %.3f, Fake Ratio %.3f\r" \
%(cnt_total, loss_d_out, loss_g_out, fake_rate_out), end="")
# Save every 100 batches
if cnt_total % 50 == 0:
self.saver.save(self.sess, "./log/model_%3d.ckpt" %(cnt_total//50))
except tf.errors.OutOfRangeError as e:
print("fetch data ended")
coord.request_stop()
coord.join(threads)
def infer_gen(self, model_path = None, n_img = 1, prior=None):
"""
After the training, now we can use generator images!
n_img: number of images, if not given any prior
prior: given priors, if None, then random generate
"""
if not self.sess:
self.sess = tf.Session()
if not model_path:
print("Invalid model path!")
sys.exit()
else:
self.saver.restore(self.sess, model_path)
if not prior:
prior = np.random.randn(n_img, 100)* 1e-2
imgs = self.sess.run(self.generator.gen_img_out, feed_dict = {self.generator.prior: prior, self.generator.is_training: False})
return imgs
def infer_dis(self):
"""
In fact, discriminator can be used to predict,
but here we will not complete the code
"""
pass
In [ ]:
if __name__ == "__main__":
tf.reset_default_graph()
dp = DataProcess("../dataset/102flowers/jpg/")
g = Generator()
d = Discriminator(g)
gan = GANModel(generator=g, discriminator=d, datasrc=dp)
gan.train()
In [ ]: