In [4]:
%matplotlib inline

In [13]:
from keras.models import Sequential
from keras.layers import Dense, Reshape
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math

In [ ]:
def generator_model():
    """Build the DCGAN generator.

    Maps a 100-dimensional noise vector to a 28x28x1 image whose pixel
    values lie in [-1, 1] (final tanh), matching the rescaled MNIST data.
    """
    model = Sequential()
    model.add(Dense(1024, input_dim=100))
    model.add(Activation('tanh'))
    model.add(Dense(7 * 7 * 128))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128), input_shape=(7 * 7 * 128,)))
    model.add(UpSampling2D(size=(2, 2)))  # 7x7 -> 14x14
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))  # 14x14 -> 28x28
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model
  • UpSamplingで出力の特徴マップのサイズを拡大しながらMNISTの28x28に近づける
  • Conv2Dでチャネル数(特徴マップ数)は減らしていく

In [20]:
# Build the generator and inspect its layer output shapes / parameter counts.
gen = generator_model()
gen.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_15 (Dense)             (None, 1024)              103424    
_________________________________________________________________
activation_23 (Activation)   (None, 1024)              0         
_________________________________________________________________
dense_16 (Dense)             (None, 6272)              6428800   
_________________________________________________________________
batch_normalization_8 (Batch (None, 6272)              25088     
_________________________________________________________________
activation_24 (Activation)   (None, 6272)              0         
_________________________________________________________________
reshape_8 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_12 (UpSampling (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 64)        204864    
_________________________________________________________________
activation_25 (Activation)   (None, 14, 14, 64)        0         
_________________________________________________________________
up_sampling2d_13 (UpSampling (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 28, 28, 1)         1601      
_________________________________________________________________
activation_26 (Activation)   (None, 28, 28, 1)         0         
=================================================================
Total params: 6,763,777
Trainable params: 6,751,233
Non-trainable params: 12,544
_________________________________________________________________

In [22]:
def discriminator_model():
    """Build the DCGAN discriminator.

    A conventional CNN: 28x28x1 image in, single sigmoid probability out
    (1 = real MNIST digit, 0 = generated).
    """
    model = Sequential()
    model.add(Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))  # 28x28 -> 14x14
    model.add(Conv2D(128, (5, 5)))             # valid padding: 14x14 -> 10x10
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))  # 10x10 -> 5x5
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model
  • MaxPoolingで特徴マップサイズを縮小
  • チャンネル数は増やしていく
  • 通常のCNNと同じ構成

In [23]:
# Build the discriminator and inspect its layer output shapes / parameter counts.
dis = discriminator_model()
dis.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_11 (Conv2D)           (None, 28, 28, 64)        1664      
_________________________________________________________________
activation_27 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 10, 10, 128)       204928    
_________________________________________________________________
activation_28 (Activation)   (None, 10, 10, 128)       0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 5, 5, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 3200)              0         
_________________________________________________________________
dense_17 (Dense)             (None, 1024)              3277824   
_________________________________________________________________
activation_29 (Activation)   (None, 1024)              0         
_________________________________________________________________
dense_18 (Dense)             (None, 1)                 1025      
_________________________________________________________________
activation_30 (Activation)   (None, 1)                 0         
=================================================================
Total params: 3,485,441
Trainable params: 3,485,441
Non-trainable params: 0
_________________________________________________________________

In [25]:
def generator_containing_discriminator(generator, discriminator):
    """Stack the generator and a frozen discriminator (G -> D).

    Used for the generator's training step: gradients flow through D back
    into G, but D's own weights are frozen. The ``trainable`` flag only takes
    effect when this stacked model is compiled, so the standalone
    discriminator can still be compiled as trainable separately.
    """
    discriminator.trainable = False
    stacked = Sequential()
    stacked.add(generator)
    stacked.add(discriminator)
    return stacked

In [28]:
# Stacked G->D model; note the non-trainable parameter count equals the
# discriminator's weights (plus BatchNorm moving statistics).
model = generator_containing_discriminator(gen, dis)
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_5 (Sequential)    (None, 28, 28, 1)         6763777   
_________________________________________________________________
sequential_6 (Sequential)    (None, 1)                 3485441   
=================================================================
Total params: 10,249,218
Trainable params: 6,751,233
Non-trainable params: 3,497,985
_________________________________________________________________

In [29]:
def combine_images(generated_images):
    """Tile a batch of images into one near-square 2D grid.

    Parameters
    ----------
    generated_images : ndarray of shape (N, H, W, 1)
        Batch of single-channel images.

    Returns
    -------
    ndarray of shape (rows*H, cols*W)
        Images laid out row-major on a grid with cols = floor(sqrt(N)).
    """
    num = generated_images.shape[0]
    cols = int(math.sqrt(num))
    rows = int(math.ceil(float(num) / cols))
    h, w = generated_images.shape[1:3]
    canvas = np.zeros((rows * h, cols * w), dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        row, col = divmod(index, cols)
        canvas[row * h:(row + 1) * h, col * w:(col + 1) * w] = img[:, :, 0]
    return canvas

In [79]:
def train(batch_size, epochs=20):
    """Train a DCGAN on MNIST.

    Parameters
    ----------
    batch_size : int
        Number of real images per batch (the same number of fakes is
        generated, so the discriminator sees 2 * batch_size samples).
    epochs : int, optional
        Passes over the training set (default 20, the previously
        hard-coded value; kept as default for backward compatibility).

    Side effects: saves a tiled sample image every 20 batches as
    '<epoch>-<batch>.png' in the working directory and prints per-batch
    generator/discriminator losses.
    """
    # Load MNIST and rescale pixels to [-1, 1] to match the generator's tanh output.
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)

    discriminator = discriminator_model()
    generator = generator_model()
    # Stacked G->D model; building it sets discriminator.trainable = False.
    discriminator_on_generator = generator_containing_discriminator(generator, discriminator)

    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)

    # NOTE(review): compiling the generator alone is likely unnecessary —
    # predict() works on an uncompiled model and G is only trained through
    # the stacked model. Kept to preserve the original behavior.
    generator.compile(loss='binary_crossentropy', optimizer='SGD')

    # Compile the stack while D is frozen: its train function updates G only.
    discriminator_on_generator.summary()
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)

    # Re-enable training and compile the standalone discriminator; trainable
    # flags are baked in at compile time, so both behaviors coexist.
    discriminator.trainable = True
    discriminator.summary()
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    # Loop-invariant: number of full batches per epoch (remainder dropped).
    num_batches = X_train.shape[0] // batch_size

    for epoch in range(epochs):
        print('epoch:', epoch)
        print('number of batches', num_batches)
        for index in range(num_batches):
            # Vectorized noise sampling (replaces a per-row Python loop).
            noise = np.random.uniform(-1, 1, (batch_size, 100))
            image_batch = X_train[index * batch_size:(index + 1) * batch_size]
            generated_images = generator.predict(noise, verbose=0)

            # Periodically save a tiled snapshot of the generator's output.
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5  # back to [0, 255] for PNG
                Image.fromarray(image.astype(np.uint8)).save('%d-%d.png' % (epoch, index))

            # Discriminator step: real images labelled 1, fakes labelled 0.
            X = np.concatenate((image_batch, generated_images))
            y = [1] * batch_size + [0] * batch_size
            d_loss = discriminator.train_on_batch(X, y)

            # Generator step: fresh noise, all-ones labels; D is frozen
            # inside the stacked model, so only G's weights move here.
            noise = np.random.uniform(-1, 1, (batch_size, 100))
            g_loss = discriminator_on_generator.train_on_batch(noise, [1] * batch_size)

            print('epoch: %d, batch: %d, g_loss: %f, d_loss: %f' % (epoch, index, g_loss, d_loss))
  • generator.compile(loss='binary_crossentropy', optimizer='SGD')は必要?
  • compile()していなくてもpredict()はできるのでは?
  • generatorの訓練はdiscriminator_on_generatorが行う
  • discriminator_on_generatorはモデル作成時にdiscriminator.trainable=Falseでcompileしているので訓練時に設定は不要では?
  • 更新前と更新後のdiscriminator(layers[1])の重みを比較するとフリーズされていることがわかる
  • trainableを設定しただけではsummary()には反映されるがcompile()しないと反映されない
  • a1とa2を比較するとdiscriminatorの重みが共有されていることがわかる

In [80]:
train(128)


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_70 (Sequential)   (None, 28, 28, 1)         6763777   
_________________________________________________________________
sequential_69 (Sequential)   (None, 1)                 3485441   
=================================================================
Total params: 10,249,218
Trainable params: 6,751,233
Non-trainable params: 3,497,985
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_93 (Conv2D)           (None, 28, 28, 64)        1664      
_________________________________________________________________
activation_191 (Activation)  (None, 28, 28, 64)        0         
_________________________________________________________________
max_pooling2d_43 (MaxPooling (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_94 (Conv2D)           (None, 10, 10, 128)       204928    
_________________________________________________________________
activation_192 (Activation)  (None, 10, 10, 128)       0         
_________________________________________________________________
max_pooling2d_44 (MaxPooling (None, 5, 5, 128)         0         
_________________________________________________________________
flatten_22 (Flatten)         (None, 3200)              0         
_________________________________________________________________
dense_99 (Dense)             (None, 1024)              3277824   
_________________________________________________________________
activation_193 (Activation)  (None, 1024)              0         
_________________________________________________________________
dense_100 (Dense)            (None, 1)                 1025      
_________________________________________________________________
activation_194 (Activation)  (None, 1)                 0         
=================================================================
Total params: 3,485,441
Trainable params: 3,485,441
Non-trainable params: 0
_________________________________________________________________
epoch: 0
number of batches 468
True
epoch: 0, batch: 0, g_loss: 0.708722, d_loss: 0.724036
True
epoch: 0, batch: 1, g_loss: 0.705510, d_loss: 0.722274
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-80-989ebd93b810> in <module>()
----> 1 train(128)

<ipython-input-79-ea5b4353c394> in train(BATCH_SIZE)
     53 
     54 #            before_weights = discriminator_on_generator.layers[1].get_weights()[0]
---> 55             g_loss = discriminator_on_generator.train_on_batch(noise, [1] * BATCH_SIZE)
     56 #            after_weights = discriminator_on_generator.layers[1].get_weights()[0]
     57 #            print(np.array_equal(before_weights, after_weights))

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/models.py in train_on_batch(self, x, y, class_weight, sample_weight)
    942         return self.model.train_on_batch(x, y,
    943                                          sample_weight=sample_weight,
--> 944                                          class_weight=class_weight)
    945 
    946     def test_on_batch(self, x, y,

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1631             ins = x + y + sample_weights
   1632         self._make_train_function()
-> 1633         outputs = self.train_function(ins)
   1634         if len(outputs) == 1:
   1635             return outputs[0]

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2227         session = get_session()
   2228         updated = session.run(self.outputs + [self.updates_op],
-> 2229                               feed_dict=feed_dict)
   2230         return updated[:len(self.outputs)]
   2231 

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    787     try:
    788       result = self._run(None, fetches, feed_dict, options_ptr,
--> 789                          run_metadata_ptr)
    790       if run_metadata:
    791         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    995     if final_fetches or final_targets:
    996       results = self._do_run(handle, final_targets, final_fetches,
--> 997                              feed_dict_string, options, run_metadata)
    998     else:
    999       results = []

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1130     if handle is None:
   1131       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1132                            target_list, options, run_metadata)
   1133     else:
   1134       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1137   def _do_call(self, fn, *args):
   1138     try:
-> 1139       return fn(*args)
   1140     except errors.OpError as e:
   1141       message = compat.as_text(e.message)

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1119         return tf_session.TF_Run(session, options,
   1120                                  feed_dict, fetch_list, target_list,
-> 1121                                  status, run_metadata)
   1122 
   1123     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: