Generating Handwritten Digits with a GAN


In [29]:
import pickle as pkl

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

1. Loading the Data


In [30]:
mnist = input_data.read_data_sets('MNIST_data')


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

2. Defining the Graph

2.1 Defining the Inputs


In [31]:
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real') 
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

2.2 Generator and Discriminator


In [32]:
# Generator
def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):
        # z => leaky ReLU => tanh
        h1 = tf.layers.dense(z, n_units, activation=None)

        # Leaky ReLU (see the note after the discriminator below);
        # the alpha parameter is needed because the sampling cell in
        # section 5 calls generator(..., alpha=alpha)
        h1 = tf.maximum(alpha * h1, h1)

        # tanh squashes the output into (-1, 1), matching the rescaled inputs
        logits = tf.layers.dense(h1, out_dim, activation=None)
        out = tf.tanh(logits)

        return out

In [33]:
# Discriminator
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):
        # x (real images or generator output) => leaky ReLU => sigmoid
        h1 = tf.layers.dense(x, n_units, activation=None)

        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)

        # Single logit: the probability that x is a real image
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)

        return out, logits
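
The `tf.maximum(alpha * h1, h1)` idiom is a compact way of writing the leaky ReLU activation (valid for $0 < \alpha < 1$):

$$\text{LeakyReLU}(x) = \max(\alpha x,\, x) = \begin{cases} x, & x \ge 0 \\ \alpha x, & x < 0 \end{cases}$$

Unlike a plain ReLU, it keeps a small slope $\alpha$ for negative inputs, so units never die completely and gradients can still flow from the discriminator back to the generator.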

2.3 Hyperparameters


In [34]:
input_size = 784    # 28x28 MNIST images, flattened
z_size = 100        # dimension of the latent noise vector z
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# One-sided label smoothing: real labels become 1 - smooth = 0.9,
# which keeps the discriminator from growing overconfident
smooth = 0.1

2.4 Building the Network


In [35]:
input_real, input_z = model_inputs(input_size, z_size)

g_model = generator(input_z, input_size, n_units=g_hidden_size)

d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-35-ee3e3bdf675f> in <module>()
      1 input_real, input_z = model_inputs(input_size, z_size)
      2 
----> 3 g_model = generator(input_z, input_size, n_units=g_hidden_size)

[... library stack trace through tf.layers.dense / variable_scope elided ...]

ValueError: Variable generator/dense/kernel already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

  File "<ipython-input-4-cfaf5689af28>", line 5, in generator
    h1 = tf.layers.dense(z, n_units, activation=None)
  File "<ipython-input-7-ee3e3bdf675f>", line 3, in <module>
    g_model = generator(input_z, input_size, n_units=g_hidden_size)
2.5 Discriminator and Generator Losses


In [27]:
# Calculate losses
# Discriminator loss on real images, with the smoothed "real" label 1 - smooth
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))
# Discriminator loss on generated images (label 0)
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))

d_loss = d_loss_real + d_loss_fake

# Generator loss: the generator wants the discriminator to call its output real
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
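
In equation form, with one-sided label smoothing $s = 0.1$ applied only to the real labels, the cell above minimizes:

$$\mathcal{L}_D = -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[(1-s)\log D(x) + s\log(1-D(x))\big] \;-\; \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big]$$

$$\mathcal{L}_G = -\,\mathbb{E}_{z}\big[\log D(G(z))\big]$$

Note that $\mathcal{L}_G$ is the non-saturating form $-\log D(G(z))$ rather than the minimax term $\log(1 - D(G(z)))$; it gives the generator much stronger gradients early in training, when the discriminator can easily reject its samples.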

2.6 Defining the Optimizers


In [28]:
# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-ce26135605fb> in <module>()
      8 
      9 d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
---> 10 g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

[... library stack trace through AdamOptimizer._create_slots / variable_scope elided ...]

ValueError: Variable generator/dense/kernel/Adam/ already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

  File "<ipython-input-9-ce26135605fb>", line 10, in <module>
    g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

This has the same root cause as the error in section 2.4: re-running the cell asks Adam to recreate its slot variables (the per-variable first- and second-moment accumulators) that the first run already created. Resetting the default graph and re-running every graph-building cell from section 2.4 onward clears it.
3. Running the Graph in a Session and Training


In [10]:
batch_size = 100
epochs = 100
samples = []
losses = []

# Only the generator variables are checkpointed; that is all we need for sampling
saver = tf.train.Saver(var_list=g_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
                
            # Get the images and flatten them to (batch_size, 784)
            batch_images = batch[0].reshape((batch_size, 784))
            # Rescale from [0, 1] to [-1, 1] to match the generator's tanh output
            batch_images = batch_images * 2 - 1
            # Sample random noise z to seed the fake images
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            # Run the discriminator and generator optimizers in turn
            _ = sess.run(d_train_opt, feed_dict={
                input_real: batch_images,
                input_z: batch_z
            })
            _ = sess.run(g_train_opt, feed_dict={
                input_z: batch_z
            })
            
        # Report losses on the final batch of the epoch
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = sess.run(g_loss, {input_z: batch_z})
        
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))
        
        
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from the generator at the end of each epoch to monitor progress
        # (see the note after this cell about building this op just once)
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                       generator(input_z, input_size, n_units=g_hidden_size, reuse=True),
                       feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        
        saver.save(sess, './checkpoints/generator.ckpt')
        
    
# Save the generator samples collected during training
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
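
One caveat in the loop above: calling `generator(...)` inside the epoch loop adds a fresh copy of the generator ops to the graph on every epoch. The weights are shared thanks to `reuse=True`, but the graph itself keeps growing. A minimal alternative sketch, assuming the graph built in section 2.4, creates the sampling op once and reuses it:

# Build the sampling op once, before the Session; reuse=True shares the
# weights that training updates
sample_op = generator(input_z, input_size, n_units=g_hidden_size, reuse=True)

# ...then, inside the epoch loop:
#     gen_samples = sess.run(sample_op, feed_dict={input_z: sample_z})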


Epoch 1/100... Discriminator Loss: 0.3544... Generator Loss: 4.4785
Epoch 2/100... Discriminator Loss: 0.7165... Generator Loss: 2.0073
Epoch 3/100... Discriminator Loss: 0.6644... Generator Loss: 4.6733
Epoch 4/100... Discriminator Loss: 2.4130... Generator Loss: 3.2340
Epoch 5/100... Discriminator Loss: 2.2782... Generator Loss: 2.6394
Epoch 6/100... Discriminator Loss: 1.5873... Generator Loss: 1.4744
Epoch 7/100... Discriminator Loss: 1.3938... Generator Loss: 1.4747
Epoch 8/100... Discriminator Loss: 0.8521... Generator Loss: 1.7710
Epoch 9/100... Discriminator Loss: 1.2340... Generator Loss: 1.2663
Epoch 10/100... Discriminator Loss: 1.5735... Generator Loss: 0.7784
Epoch 11/100... Discriminator Loss: 1.0534... Generator Loss: 3.2012
Epoch 12/100... Discriminator Loss: 1.9542... Generator Loss: 1.4279
Epoch 13/100... Discriminator Loss: 1.9767... Generator Loss: 2.1712
Epoch 14/100... Discriminator Loss: 0.9567... Generator Loss: 2.5153
Epoch 15/100... Discriminator Loss: 3.8910... Generator Loss: 0.6774
Epoch 16/100... Discriminator Loss: 1.7681... Generator Loss: 1.5896
Epoch 17/100... Discriminator Loss: 1.1829... Generator Loss: 1.5891
Epoch 18/100... Discriminator Loss: 0.9922... Generator Loss: 1.9825
Epoch 19/100... Discriminator Loss: 1.2778... Generator Loss: 1.4737
Epoch 20/100... Discriminator Loss: 1.8206... Generator Loss: 2.8005
Epoch 21/100... Discriminator Loss: 0.9369... Generator Loss: 2.3725
Epoch 22/100... Discriminator Loss: 0.8816... Generator Loss: 2.4938
Epoch 23/100... Discriminator Loss: 0.8987... Generator Loss: 2.8349
Epoch 24/100... Discriminator Loss: 0.9187... Generator Loss: 3.3629
Epoch 25/100... Discriminator Loss: 0.9319... Generator Loss: 1.9052
Epoch 26/100... Discriminator Loss: 0.7478... Generator Loss: 2.4325
Epoch 27/100... Discriminator Loss: 0.8986... Generator Loss: 1.9025
Epoch 28/100... Discriminator Loss: 0.6721... Generator Loss: 3.5047
Epoch 29/100... Discriminator Loss: 0.8731... Generator Loss: 1.8997
Epoch 30/100... Discriminator Loss: 0.6333... Generator Loss: 2.7971
Epoch 31/100... Discriminator Loss: 0.7967... Generator Loss: 2.4373
Epoch 32/100... Discriminator Loss: 1.0417... Generator Loss: 1.5576
Epoch 33/100... Discriminator Loss: 0.6886... Generator Loss: 2.2210
Epoch 34/100... Discriminator Loss: 0.5970... Generator Loss: 2.6517
Epoch 35/100... Discriminator Loss: 0.6639... Generator Loss: 2.9884
Epoch 36/100... Discriminator Loss: 0.6464... Generator Loss: 2.6809
Epoch 37/100... Discriminator Loss: 0.7245... Generator Loss: 2.7015
Epoch 38/100... Discriminator Loss: 0.8438... Generator Loss: 2.2406
Epoch 39/100... Discriminator Loss: 0.6605... Generator Loss: 3.4752
Epoch 40/100... Discriminator Loss: 0.8480... Generator Loss: 2.7831
Epoch 41/100... Discriminator Loss: 0.6852... Generator Loss: 2.7074
Epoch 42/100... Discriminator Loss: 0.8118... Generator Loss: 2.6830
Epoch 43/100... Discriminator Loss: 1.0298... Generator Loss: 2.8299
Epoch 44/100... Discriminator Loss: 0.8266... Generator Loss: 2.5185
Epoch 45/100... Discriminator Loss: 0.6413... Generator Loss: 2.4401
Epoch 46/100... Discriminator Loss: 0.7249... Generator Loss: 4.0528
Epoch 47/100... Discriminator Loss: 0.8862... Generator Loss: 1.9536
Epoch 48/100... Discriminator Loss: 0.9206... Generator Loss: 2.5433
Epoch 49/100... Discriminator Loss: 0.8777... Generator Loss: 2.9774
Epoch 50/100... Discriminator Loss: 0.7546... Generator Loss: 3.2317
Epoch 51/100... Discriminator Loss: 0.6575... Generator Loss: 2.9100
Epoch 52/100... Discriminator Loss: 0.7172... Generator Loss: 2.9132
Epoch 53/100... Discriminator Loss: 0.8816... Generator Loss: 2.2341
Epoch 54/100... Discriminator Loss: 0.8433... Generator Loss: 2.2138
Epoch 55/100... Discriminator Loss: 0.8684... Generator Loss: 2.4706
Epoch 56/100... Discriminator Loss: 0.7558... Generator Loss: 2.1506
Epoch 57/100... Discriminator Loss: 0.8512... Generator Loss: 2.5543
Epoch 58/100... Discriminator Loss: 0.9612... Generator Loss: 2.2605
Epoch 59/100... Discriminator Loss: 0.6947... Generator Loss: 2.5379
Epoch 60/100... Discriminator Loss: 0.7418... Generator Loss: 3.3872
Epoch 61/100... Discriminator Loss: 0.6932... Generator Loss: 2.2419
Epoch 62/100... Discriminator Loss: 0.7928... Generator Loss: 2.2659
Epoch 63/100... Discriminator Loss: 0.7963... Generator Loss: 2.8519
Epoch 64/100... Discriminator Loss: 0.7620... Generator Loss: 2.7281
Epoch 65/100... Discriminator Loss: 0.9336... Generator Loss: 1.8140
Epoch 66/100... Discriminator Loss: 0.8493... Generator Loss: 2.4214
Epoch 67/100... Discriminator Loss: 0.8742... Generator Loss: 2.1912
Epoch 68/100... Discriminator Loss: 0.9805... Generator Loss: 1.8804
Epoch 69/100... Discriminator Loss: 0.6433... Generator Loss: 2.9091
Epoch 70/100... Discriminator Loss: 0.7438... Generator Loss: 2.8397
Epoch 71/100... Discriminator Loss: 0.7164... Generator Loss: 2.3444
Epoch 72/100... Discriminator Loss: 0.9627... Generator Loss: 2.5519
Epoch 73/100... Discriminator Loss: 0.8996... Generator Loss: 2.0538
Epoch 74/100... Discriminator Loss: 0.7520... Generator Loss: 2.1769
Epoch 75/100... Discriminator Loss: 1.1362... Generator Loss: 1.3790
Epoch 76/100... Discriminator Loss: 0.7726... Generator Loss: 2.2651
Epoch 77/100... Discriminator Loss: 0.8428... Generator Loss: 2.0681
Epoch 78/100... Discriminator Loss: 1.0779... Generator Loss: 1.7973
Epoch 79/100... Discriminator Loss: 1.1857... Generator Loss: 2.0028
Epoch 80/100... Discriminator Loss: 1.0319... Generator Loss: 1.6755
Epoch 81/100... Discriminator Loss: 0.8388... Generator Loss: 2.0686
Epoch 82/100... Discriminator Loss: 0.9276... Generator Loss: 1.8606
Epoch 83/100... Discriminator Loss: 0.6687... Generator Loss: 2.8781
Epoch 84/100... Discriminator Loss: 0.8562... Generator Loss: 2.2028
Epoch 85/100... Discriminator Loss: 0.8220... Generator Loss: 2.5882
Epoch 86/100... Discriminator Loss: 0.7642... Generator Loss: 2.2159
Epoch 87/100... Discriminator Loss: 0.8680... Generator Loss: 2.0097
Epoch 88/100... Discriminator Loss: 0.8786... Generator Loss: 2.2172
Epoch 89/100... Discriminator Loss: 0.9259... Generator Loss: 2.0287
Epoch 90/100... Discriminator Loss: 0.8486... Generator Loss: 3.9869
Epoch 91/100... Discriminator Loss: 0.9319... Generator Loss: 1.9108
Epoch 92/100... Discriminator Loss: 0.8624... Generator Loss: 1.6981
Epoch 93/100... Discriminator Loss: 0.8774... Generator Loss: 2.0098
Epoch 94/100... Discriminator Loss: 0.8198... Generator Loss: 2.1095
Epoch 95/100... Discriminator Loss: 0.8303... Generator Loss: 1.9160
Epoch 96/100... Discriminator Loss: 1.0276... Generator Loss: 1.8844
Epoch 97/100... Discriminator Loss: 1.3044... Generator Loss: 1.8199
Epoch 98/100... Discriminator Loss: 1.0317... Generator Loss: 1.6128
Epoch 99/100... Discriminator Loss: 0.9975... Generator Loss: 1.9223
Epoch 100/100... Discriminator Loss: 0.8204... Generator Loss: 2.6106

4. Analyzing the Results


In [12]:
# Plot the training loss curves
%matplotlib inline

fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()


Out[12]:
<matplotlib.legend.Legend at 0x7f752335b9e8>

5. Generating Images


In [13]:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

# Load the samples saved during training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

In [17]:
_ = view_samples(-1, samples)



In [18]:
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)



In [19]:
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    gen_samples = sess.run(
                   generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                   feed_dict={input_z: sample_z})
_ = view_samples(0, [gen_samples])


INFO:tensorflow:Restoring parameters from None
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: expected bytes, NoneType found

The above exception was the direct cause of the following exception:

SystemError                               Traceback (most recent call last)
<ipython-input-19-63c3a8e21010> in <module>()
      1 saver = tf.train.Saver(var_list=g_vars)
      2 with tf.Session() as sess:
----> 3     saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))

[... library stack trace through Saver.restore / Session.run elided ...]

SystemError: <built-in function TF_Run> returned a result with an error set
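
`tf.train.latest_checkpoint('checkpoints')` returned `None` because no checkpoint metadata file was found under that directory, and `Saver.restore(sess, None)` surfaces as the opaque `SystemError` above. A minimal sketch of a more defensive restore, assuming training saved to `./checkpoints/generator.ckpt` as in section 3:

import os

# Create the directory up front so saver.save() and latest_checkpoint() agree
os.makedirs('checkpoints', exist_ok=True)

ckpt = tf.train.latest_checkpoint('./checkpoints')
assert ckpt is not None, "no checkpoint found under ./checkpoints -- run training first"

saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, ckpt)
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    gen_samples = sess.run(
        generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
        feed_dict={input_z: sample_z})

_ = view_samples(0, [gen_samples])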

In [ ]: