DCGAN with Labeled Faces in the Wild Dataset


In [8]:
%matplotlib inline
import matplotlib as mpl
mpl.use('Agg')

import os
import h5py
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense, Activation, Reshape, Flatten, Dropout, MaxPooling2D, GlobalAveragePooling2D
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

np.random.seed(7)


/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/matplotlib/__init__.py:1357: UserWarning:  This call to matplotlib.use() has no effect
because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.

  warnings.warn(_use_error_msg)

データセットをロード


In [14]:
# Load the LFW face dataset from HDF5 and normalise pixels for the GAN.
# Use a context manager so the file handle is closed after reading
# (the original left `f` open for the rest of the session).
with h5py.File('./data/lfw.hdf5', 'r') as f:
    print(list(f.keys()))
    lfw = f['lfw'][:]  # materialise the whole dataset in memory: (N, 3, 64, 64)
# Map [0, 1] pixel values to [-1, 1] to match the generator's tanh output.
lfw = (lfw - 0.5) / 0.5
print(np.max(lfw), np.min(lfw))
print(lfw.shape)
# Convert channels-first (N, 3, 64, 64) to channels-last (N, 64, 64, 3).
X_train = lfw.transpose((0, 2, 3, 1))
print(X_train.shape)


['lfw']
1.0 -1.0
(13233, 3, 64, 64)
(13233, 64, 64, 3)

In [16]:
# Show one training face; undo the [-1, 1] normalisation for display.
plt.imshow(X_train[1] * 0.5 + 0.5)


Out[16]:
<matplotlib.image.AxesImage at 0x121c40208>

DCGANモデルを作成


In [10]:
def generator_model():
    """Build the DCGAN generator.

    Maps a 100-dim noise vector to a 64x64x3 image in [-1, 1] (tanh output),
    upsampling the spatial size 4 -> 8 -> 16 -> 32 -> 64.

    Returns:
        A Keras `Sequential` model.
    """
    model = Sequential()
    # Project the noise vector to a flat 4*4*1024 feature vector.
    model.add(Dense(4 * 4 * 1024, input_dim=100))
    # NOTE: the original also passed `input_shape` here, but Keras ignores
    # `input_shape` on any layer that is not the first one, so the redundant
    # argument has been dropped.
    model.add(Reshape((4, 4, 1024)))
    model.add(Activation('relu'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(512, (5, 5), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(256, (5, 5), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(UpSampling2D((2, 2)))
    model.add(Conv2D(128, (5, 5), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(UpSampling2D((2, 2)))
    # Final conv maps to 3 channels; tanh keeps outputs in [-1, 1].
    model.add(Conv2D(3, (3, 3), padding='same'))
    model.add(Activation('tanh'))
    return model

def discriminator_model():
    """Build the DCGAN discriminator.

    Strided 5x5 convolutions halve the spatial size 64 -> 32 -> 16 -> 8 -> 4,
    then global average pooling and a single sigmoid unit score P(real).

    Returns:
        A Keras `Sequential` model taking (64, 64, 3) images.
    """
    layers = [
        Conv2D(128, (5, 5), strides=(2, 2), padding='same', input_shape=(64, 64, 3)),
        LeakyReLU(0.2),
        Dropout(0.2),
        Conv2D(256, (5, 5), strides=(2, 2), padding='same'),
        LeakyReLU(0.2),
        Dropout(0.2),
        Conv2D(512, (5, 5), strides=(2, 2), padding='same'),
        LeakyReLU(0.2),
        Conv2D(1024, (5, 5), strides=(2, 2), padding='same'),
        LeakyReLU(0.2),
        GlobalAveragePooling2D(),
        Dense(1),
        Activation('sigmoid'),
    ]
    model = Sequential(layers)
    return model

# Instantiate the discriminator once just to print its layer summary below.
model = discriminator_model()
model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_17 (Conv2D)           (None, 32, 32, 128)       9728      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 128)       0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 16, 16, 256)       819456    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 8, 8, 512)         3277312   
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 512)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 4, 4, 1024)        13108224  
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 4, 1024)        0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1024)              0         
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 1025      
_________________________________________________________________
activation_19 (Activation)   (None, 1)                 0         
=================================================================
Total params: 17,215,745
Trainable params: 17,215,745
Non-trainable params: 0
_________________________________________________________________

In [21]:
def plot_images(images, fname=None, figsize=(12, 12), rows=4):
    """Plot a grid of images and optionally save it to disk.

    Parameters
    ----------
    images : array of shape (n, height, width, channels) with values in [-1, 1].
        n must be a multiple of `rows` (generalised from the original
        hard-coded requirement that n == 16).
    fname : str or None
        If given, the figure is also saved to this path.
    figsize : tuple of (width, height)
        Matplotlib figure size in inches.
    rows : int
        Number of rows in the grid.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (callers in long loops can close it to free memory).
    """
    assert images.shape[0] % rows == 0, 'number of images must be a multiple of rows'
    # images are in [-1, 1]; map back to [0, 1] so imshow renders them.
    images = (images / 2) + 0.5
    f = plt.figure(figsize=figsize)
    for i in range(len(images)):
        sp = f.add_subplot(rows, len(images) // rows, i + 1)
        sp.axis('off')
        plt.imshow(images[i])
    if fname is not None:
        plt.savefig(fname)
    return f

In [22]:
# Training hyperparameters and output locations.
BATCH_SIZE = 64
NUM_EPOCH = 100
GENERATED_IMAGE_PATH = 'generated_images_lfw/'
MODEL_PATH = 'models_lfw/'

if not os.path.exists(GENERATED_IMAGE_PATH):
    os.mkdir(GENERATED_IMAGE_PATH)

if not os.path.exists(MODEL_PATH):
    os.mkdir(MODEL_PATH)

# Build the standalone discriminator model.
# IMPORTANT: it is compiled here, *before* `trainable` is flipped below, so
# this compiled model keeps its weights updatable.
discriminator = discriminator_model()
d_opt = Adam(lr=1e-5, beta_1=0.1)
discriminator.compile(loss='binary_crossentropy', optimizer=d_opt)
discriminator.summary()

# Build the stacked generator + discriminator model.
# The discriminator weights are frozen *only inside* `dcgan`:
# `trainable` takes effect at compile() time, so the already-compiled
# `discriminator` above still trains while the copy inside `dcgan` does not.
# This is visible in the summary() output (non-trainable params).
discriminator.trainable = False
generator = generator_model()
# The discriminator scores the images the generator produces.
dcgan = Sequential([generator, discriminator])
g_opt = Adam(lr=2e-4, beta_1=0.5)
dcgan.compile(loss='binary_crossentropy', optimizer=g_opt)
dcgan.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_21 (Conv2D)           (None, 32, 32, 128)       9728      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 32, 32, 128)       0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 16, 16, 256)       819456    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 8, 8, 512)         3277312   
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 8, 8, 512)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 4, 4, 1024)        13108224  
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 4, 4, 1024)        0         
_________________________________________________________________
global_average_pooling2d_2 ( (None, 1024)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 1025      
_________________________________________________________________
activation_20 (Activation)   (None, 1)                 0         
=================================================================
Total params: 17,215,745
Trainable params: 17,215,745
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_8 (Sequential)    (None, 64, 64, 3)         18865923  
_________________________________________________________________
sequential_7 (Sequential)    (None, 1)                 17215745  
=================================================================
Total params: 36,081,668
Trainable params: 18,864,131
Non-trainable params: 17,217,537
_________________________________________________________________

訓練


In [25]:
num_batches = int(X_train.shape[0] / BATCH_SIZE)
print('Number of batches:', num_batches)

d_loss_history = []
g_loss_history = []

for epoch in range(NUM_EPOCH):
    # BUG FIX: the original iterated `range(1)` — a single batch per epoch,
    # almost certainly a debugging leftover given that `num_batches` is
    # computed and printed above. Sweep the full training set instead.
    for index in range(num_batches):
        # Noise vectors (BATCH_SIZE uniform samples of dim 100) for the generator.
        noise = np.array([np.random.uniform(-1, 1, 100) for _ in range(BATCH_SIZE)])

        # Real images (training data) for this batch.
        image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]

        # Fake images from the current generator.
        generated_images = generator.predict(noise, verbose=0)

        # Update the discriminator: real images labelled 1, generated 0.
        X = np.concatenate((image_batch, generated_images))
        y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
        d_loss = discriminator.train_on_batch(X, y)
        d_loss_history.append(d_loss)

        # Update the generator (through the frozen discriminator in dcgan):
        # fresh noise, labelled 1 so the generator learns to fool it.
        noise = np.array([np.random.uniform(-1, 1, 100) for _ in range(BATCH_SIZE)])
        g_loss = dcgan.train_on_batch(noise, [1] * BATCH_SIZE)
        g_loss_history.append(g_loss)

        print('epoch: %d, batch: %d, g_loss: %f, d_loss: %f' % (epoch, index, g_loss, d_loss))

    # Save a sample grid of generated images at the end of each epoch.
    print(np.min(generated_images), np.max(generated_images))
    plot_images(generated_images[:16], GENERATED_IMAGE_PATH + 'epoch-%04d.png' % (epoch))

    # Checkpoint both models; the last batch losses go into the filenames.
    generator.save('%s/generator-%03d-%.2f.h5' % (MODEL_PATH, epoch, g_loss))
    discriminator.save('%s/discriminator-%03d-%.2f.h5' % (MODEL_PATH, epoch, d_loss))


Number of batches: 103
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-25-ed08146c424d> in <module>()
     14 
     15         # 生成画像
---> 16         generated_images = generator.predict(noise, verbose=0)
     17 
     18         # discriminatorを更新

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/models.py in predict(self, x, batch_size, verbose)
    900         if self.model is None:
    901             self.build()
--> 902         return self.model.predict(x, batch_size=batch_size, verbose=verbose)
    903 
    904     def predict_on_batch(self, x):

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose)
   1583         f = self.predict_function
   1584         return self._predict_loop(f, ins,
-> 1585                                   batch_size=batch_size, verbose=verbose)
   1586 
   1587     def train_on_batch(self, x, y,

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose)
   1210                 ins_batch = _slice_arrays(ins, batch_ids)
   1211 
-> 1212             batch_outs = f(ins_batch)
   1213             if not isinstance(batch_outs, list):
   1214                 batch_outs = [batch_outs]

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2227         session = get_session()
   2228         updated = session.run(self.outputs + [self.updates_op],
-> 2229                               feed_dict=feed_dict)
   2230         return updated[:len(self.outputs)]
   2231 

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    787     try:
    788       result = self._run(None, fetches, feed_dict, options_ptr,
--> 789                          run_metadata_ptr)
    790       if run_metadata:
    791         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    995     if final_fetches or final_targets:
    996       results = self._do_run(handle, final_targets, final_fetches,
--> 997                              feed_dict_string, options, run_metadata)
    998     else:
    999       results = []

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1130     if handle is None:
   1131       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1132                            target_list, options, run_metadata)
   1133     else:
   1134       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1137   def _do_call(self, fn, *args):
   1138     try:
-> 1139       return fn(*args)
   1140     except errors.OpError as e:
   1141       message = compat.as_text(e.message)

/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1119         return tf_session.TF_Run(session, options,
   1120                                  feed_dict, fetch_list, target_list,
-> 1121                                  status, run_metadata)
   1122 
   1123     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [26]:
# Sanity check: plot the first 16 real training faces and save to test.png.
plot_images(X_train[:16], 'test.png')



In [27]:
# Render the saved PNG inline in the notebook.
from IPython.display import Image
Image('test.png')


Out[27]:

In [ ]: