In [1]:
import os
os.environ['KERAS_BACKEND']='theano' # tensorflow also works
os.environ['THEANO_FLAGS']='floatX=float32,device=cuda'

Based on the official implementation of the WGAN paper: https://github.com/martinarjovsky/WassersteinGAN


In [2]:
import keras.backend as K
K.set_image_data_format('channels_first')
from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02)


Using Theano backend.
Using cuDNN version 5105 on context None
Mapped name None to device cuda: GeForce GTX 1080 (0000:01:00.0)

Differences from the official implementation:

  • extra_layers is not used (only because it was not needed here; it can be added).
  • The computed feature-map sizes differ; for $isize = 2^n$ they are identical (see the sketch below).
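
A minimal sketch (not part of the original notebook) tracing the discriminator's feature-map sizes: each ZeroPadding2D followed by a 4x4 stride-2 convolution maps a size s to s // 2, so for $isize = 2^n$ the pyramid bottoms out at 4x4 and the final convolution (kernel_size = csize) reduces it to 1x1, exactly as in the official network.

def pyramid_sizes(isize):
    sizes = [isize]
    csize = isize // 2        # after the initial stride-2 convolution
    sizes.append(csize)
    while csize > 5:          # same stopping rule as DCGAN_D below
        csize //= 2           # ZeroPadding2D + 4x4/stride-2 conv: s -> s // 2
        sizes.append(csize)
    return sizes              # the final conv uses kernel_size == sizes[-1]

print(pyramid_sizes(32))      # [32, 16, 8, 4]     -- a power of two
print(pyramid_sizes(48))      # [48, 24, 12, 6, 3] -- diverges from the official sizes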

In [3]:
def DCGAN_D(isize, nz, nc, ndf, n_extra_layers=0):
    _ = inputs = Input(shape=(nc, isize, isize))
    _ = ZeroPadding2D(name = 'initial.padding.{0}'.format(nc))(_)
    _ = Conv2D(filters=ndf, kernel_size=4, strides=2, use_bias=False,
                        kernel_initializer = conv_init, 
                        name = 'initial.conv.{0}-{1}'.format(nc, ndf)             
                        ) (_)
    _ = LeakyReLU(alpha=0.2, name = 'initial.relu.{0}'.format(ndf))(_)
    csize, cndf = isize // 2, ndf
    while csize > 5:
        in_feat = cndf
        out_feat = cndf*2
        _ = ZeroPadding2D(name = 'pyramid.{0}.padding'.format(in_feat))(_)
        _ = Conv2D(filters=out_feat, kernel_size=4, strides=2, use_bias=False, 
                        kernel_initializer = conv_init, 
                        name = 'pyramid.{0}-{1}.conv'.format(in_feat, out_feat)
                        ) (_)        
        _ = BatchNormalization(name = 'pyramid.{0}.batchnorm'.format(out_feat),
                                   axis=1, epsilon=1.01e-5,
                                   gamma_initializer = gamma_init)(_, training=1)
        _ = LeakyReLU(alpha=0.2, name = 'pyramid.{0}.relu'.format(out_feat))(_)
        csize, cndf = csize//2, cndf*2
    _ = Conv2D(filters=1, kernel_size=csize, strides=1, use_bias=False,
                        kernel_initializer = conv_init,
                        name = 'final.{0}-{1}.conv'.format(cndf, 1)
                        ) (_)
    outputs = Flatten()(_)
    return Model(inputs=inputs, outputs=outputs)

In [4]:
def DCGAN_G(isize, nz, nc, ngf, n_extra_layers=0):
    cngf= ngf//2
    tisize = isize
    while tisize > 5:
        cngf = cngf * 2
        tisize = tisize // 2
    _ = inputs = Input(shape=(nz,))
    _ = Reshape((nz, 1,1))(_)
    _ = Conv2DTranspose(filters=cngf, kernel_size=tisize, strides=1, use_bias=False,
                           kernel_initializer = conv_init, 
                           name = 'initial.{0}-{1}.convt'.format(nz, cngf))(_)
    _ = BatchNormalization(axis=1, epsilon=1.01e-5, gamma_initializer = gamma_init, name = 'initial.{0}.batchnorm'.format(cngf))(_, training=1)
    _ = Activation("relu", name = 'initial.{0}.relu'.format(cngf))(_)
    csize = tisize  # current spatial size of the feature maps
    
    while csize < isize//2:
        in_feat = cngf
        out_feat = cngf//2
        _ = Conv2DTranspose(filters=out_feat, kernel_size=4, strides=2, use_bias=False,
                        kernel_initializer = conv_init,
                        name = 'pyramid.{0}-{1}.convt'.format(in_feat, out_feat)             
                        ) (_)
        _ = Cropping2D(cropping=1, name = 'pyramid.{0}.cropping'.format(in_feat) )(_)
        _ = BatchNormalization(axis=1, epsilon=1.01e-5, gamma_initializer = gamma_init, name = 'pyramid.{0}.batchnorm'.format(out_feat))(_, training=1)
        _ = Activation("relu", name = 'pyramid.{0}.relu'.format(out_feat))(_)
        csize, cngf = csize*2, cngf//2
    _ = Conv2DTranspose(filters=nc, kernel_size=4, strides=2, use_bias=False,
                        kernel_initializer = conv_init,
                        name = 'final.{0}-{1}.convt'.format(cngf, nc)
                        )(_)
    _ = Cropping2D(cropping=1, name = 'final.{0}.cropping'.format(nc) )(_)
    outputs = Activation("tanh", name = 'final.{0}.tanh'.format(nc))(_)
    return Model(inputs=inputs, outputs=outputs)

Apart from the image size and the learning rates, the parameters follow the WGAN defaults.


In [5]:
nc = 3
nz = 100
ngf = 64
ndf = 64
n_extra_layers = 0
Diters = 5

imageSize = 32
batchSize = 64
lrD = 0.0003 
lrG = 0.0003 
clamp_lower, clamp_upper = -0.01, 0.01

Display the models:


In [6]:
netD = DCGAN_D(imageSize, nz, nc, ndf, n_extra_layers)
netD.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 3, 32, 32)         0         
_________________________________________________________________
initial.padding.3 (ZeroPaddi (None, 3, 34, 34)         0         
_________________________________________________________________
initial.conv.3-64 (Conv2D)   (None, 64, 16, 16)        3072      
_________________________________________________________________
initial.relu.64 (LeakyReLU)  (None, 64, 16, 16)        0         
_________________________________________________________________
pyramid.64.padding (ZeroPadd (None, 64, 18, 18)        0         
_________________________________________________________________
pyramid.64-128.conv (Conv2D) (None, 128, 8, 8)         131072    
_________________________________________________________________
pyramid.128.batchnorm (Batch (None, 128, 8, 8)         512       
_________________________________________________________________
pyramid.128.relu (LeakyReLU) (None, 128, 8, 8)         0         
_________________________________________________________________
pyramid.128.padding (ZeroPad (None, 128, 10, 10)       0         
_________________________________________________________________
pyramid.128-256.conv (Conv2D (None, 256, 4, 4)         524288    
_________________________________________________________________
pyramid.256.batchnorm (Batch (None, 256, 4, 4)         1024      
_________________________________________________________________
pyramid.256.relu (LeakyReLU) (None, 256, 4, 4)         0         
_________________________________________________________________
final.256-1.conv (Conv2D)    (None, 1, 1, 1)           4096      
_________________________________________________________________
flatten_1 (Flatten)          (None, 1)                 0         
=================================================================
Total params: 664,064
Trainable params: 663,296
Non-trainable params: 768
_________________________________________________________________

In [7]:
netG = DCGAN_G(imageSize, nz, nc, ngf, n_extra_layers)
netG.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 100)               0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 100, 1, 1)         0         
_________________________________________________________________
initial.100-256.convt (Conv2 (None, 256, 4, 4)         409600    
_________________________________________________________________
initial.256.batchnorm (Batch (None, 256, 4, 4)         1024      
_________________________________________________________________
initial.256.relu (Activation (None, 256, 4, 4)         0         
_________________________________________________________________
pyramid.256-128.convt (Conv2 (None, 128, 10, 10)       524288    
_________________________________________________________________
pyramid.256.cropping (Croppi (None, 128, 8, 8)         0         
_________________________________________________________________
pyramid.128.batchnorm (Batch (None, 128, 8, 8)         512       
_________________________________________________________________
pyramid.128.relu (Activation (None, 128, 8, 8)         0         
_________________________________________________________________
pyramid.128-64.convt (Conv2D (None, 64, 18, 18)        131072    
_________________________________________________________________
pyramid.128.cropping (Croppi (None, 64, 16, 16)        0         
_________________________________________________________________
pyramid.64.batchnorm (BatchN (None, 64, 16, 16)        256       
_________________________________________________________________
pyramid.64.relu (Activation) (None, 64, 16, 16)        0         
_________________________________________________________________
final.64-3.convt (Conv2DTran (None, 3, 34, 34)         3072      
_________________________________________________________________
final.3.cropping (Cropping2D (None, 3, 32, 32)         0         
_________________________________________________________________
final.3.tanh (Activation)    (None, 3, 32, 32)         0         
=================================================================
Total params: 1,069,824
Trainable params: 1,068,928
Non-trainable params: 896
_________________________________________________________________

In [8]:
from keras.optimizers import RMSprop, SGD, Adam

This implements the weight clipping for netD:


In [9]:
clamp_updates = [K.update(v, K.clip(v, clamp_lower, clamp_upper))
                          for v in netD.trainable_weights]
netD_clamp = K.function([],[], clamp_updates)
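
A quick sanity check (a sketch, not in the original notebook): after running the function once, every trainable weight of netD, including the BatchNormalization parameters, should lie inside [clamp_lower, clamp_upper]. K.get_value is the standard backend call for reading a variable.

netD_clamp([])                                   # apply the clip updates once
w = K.get_value(netD.trainable_weights[0])       # first conv kernel
assert w.min() >= clamp_lower and w.max() <= clamp_upper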

The code below trains netD: one part takes real images, the other takes randomly generated images; it then computes the estimated Wasserstein distance and sets up the RMSprop training function.
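
In the paper's notation, the critic (netD) is trained to minimize $L_D = \mathbb{E}_{z \sim p(z)}[D(G(z))] - \mathbb{E}_{x \sim P_r}[D(x)]$, i.e. to maximize the Wasserstein estimate $\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))]$; the code below computes exactly loss_fake - loss_real.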


In [10]:
netD_real_input = Input(shape=(nc, imageSize, imageSize))
noisev = Input(shape=(nz,))

loss_real = K.mean(netD(netD_real_input))
loss_fake = K.mean(netD(netG(noisev)))
loss = loss_fake - loss_real # sign follows the paper; the official implementation uses the opposite
training_updates = RMSprop(lr=lrD).get_updates(netD.trainable_weights,[], loss)
netD_train = K.function([netD_real_input, noisev],
                        [loss_real, loss_fake],    
                        training_updates)

The code below trains netG: generate images with netG and compute the corresponding loss.
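
The generator minimizes $L_G = -\mathbb{E}_{z \sim p(z)}[D(G(z))]$, i.e. it is pushed to raise the critic's score on its samples.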


In [11]:
loss = -loss_fake # sign follows the paper; the official implementation uses the opposite
training_updates = RMSprop(lr=lrG).get_updates(netG.trainable_weights,[], loss)
netG_train = K.function([noisev], [loss], training_updates)
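
As a quick usage check (a sketch, not in the original notebook), a single generator update can be run on random noise; numpy is imported here only because the notebook imports it further down.

import numpy as np
noise = np.random.normal(size=(batchSize, nz)).astype('float32')
errG, = netG_train([noise])   # one RMSprop step on netG; returns the current loss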

As before, download the CIFAR-10 dataset; the pixels are scaled to [-1, 1] to match the generator's tanh output.


In [12]:
from PIL import Image
import numpy as np
import tarfile

# download the dataset
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
import os
from urllib.request import urlretrieve
def reporthook(a, b, c):
    print("\rdownloading: %5.1f%%" % (a*b*100.0/c), end="")
tar_gz = "cifar-10-python.tar.gz"
if not os.path.isfile(tar_gz):
    print('Downloading data from %s' % url)
    urlretrieve(url, tar_gz, reporthook=reporthook)

import pickle
train_X=[]
train_y=[]
with tarfile.open(tar_gz) as tarf:
    for i in range(1, 6):
        dataset = "cifar-10-batches-py/data_batch_%d"%i
        print("load",dataset)
        with tarf.extractfile(dataset) as f:
            result = pickle.load(f, encoding='latin1')
        train_X.extend( result['data'].reshape(-1,3,32,32)/255*2-1)
        train_y.extend(result['labels'])
    train_X=np.float32(train_X)
    train_y=np.int32(train_y)
    dataset = "cifar-10-batches-py/test_batch"
    print("load",dataset)
    with tarf.extractfile(dataset) as f:
        result = pickle.load(f, encoding='latin1')
        test_X=np.float32(result['data'].reshape(-1,3,32,32)/255*2-1)
        test_y=np.int32(result['labels'])


load cifar-10-batches-py/data_batch_1
load cifar-10-batches-py/data_batch_2
load cifar-10-batches-py/data_batch_3
load cifar-10-batches-py/data_batch_4
load cifar-10-batches-py/data_batch_5
load cifar-10-batches-py/test_batch

To enlarge the training data, append test_X and add left-right mirrored copies:


In [13]:
train_X = np.concatenate([train_X, test_X])
train_X = np.concatenate([train_X[:,:,:,::-1], train_X])
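
After the two concatenations the dataset grows from 50,000 images to 120,000 (50,000 train + 10,000 test, then doubled by mirroring along the width axis). A quick check:

print(train_X.shape)   # expected: (120000, 3, 32, 32)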

The same utility function as before:


In [14]:
from IPython.display import display
def showX(X, rows=1):
    assert X.shape[0]%rows == 0
    int_X = ( (X+1)/2*255).clip(0,255).astype('uint8')
    # N*3072 -> N*3*32*32 -> 32 * 32N * 3
    int_X = np.moveaxis(int_X.reshape(-1,3,32,32), 1, 3)
    int_X = int_X.reshape(rows, -1, 32, 32,3).swapaxes(1,2).reshape(rows*32,-1, 3)
    display(Image.fromarray(int_X))
# the first 20 training images
showX(train_X[:20])
print(train_y[:20])
name_array = np.array("airplane automobile bird cat deer dog frog horse ship truck".split())
print(name_array[train_y[:20]])


[6 9 9 4 1 1 2 7 8 3 4 7 7 2 9 9 9 3 2 6]
['frog' 'truck' 'truck' 'deer' 'automobile' 'automobile' 'bird' 'horse'
 'ship' 'cat' 'deer' 'horse' 'horse' 'bird' 'truck' 'truck' 'truck' 'cat'
 'bird' 'frog']

In [15]:
# fixed noise used to inspect the results
fixed_noise = np.random.normal(size=(batchSize, nz)).astype('float32')

In [16]:
import time
t0 = time.time()
niter = 30
gen_iterations = 0
for epoch in range(niter):
    i = 0
    # shuffle before each epoch
    np.random.shuffle(train_X)
    batches = train_X.shape[0]//batchSize
    while i < batches:
        if gen_iterations < 25 or gen_iterations % 500 == 0:  # train the critic heavily at the start and periodically, as in the official code
            _Diters = 100
        else:
            _Diters = Diters
        j = 0
        while j < _Diters and i < batches:
            j+=1
            netD_clamp([])
            real_data = train_X[i*batchSize:(i+1)*batchSize]
            i+=1
            # differs from the official code: netG's BatchNormalization runs in training mode here when generating
            noise = np.random.normal(size=(batchSize, nz))
            errD_real, errD_fake  = netD_train([real_data, noise])
            errD = errD_real - errD_fake
        noise = np.random.normal(size=(batchSize, nz))
        # same as the official code: both netD's and netG's BatchNormalization run in training mode
        errG, = netG_train([noise])
        gen_iterations+=1        
        if gen_iterations % 500 == 0:
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
            % (epoch, niter, i, batches, gen_iterations, errD, errG, errD_real, errD_fake), time.time()-t0)
            # netG was built with training=1, so even netG.predict uses training-mode
            # BatchNormalization statistics here (differs from the official code)
            fake = netG.predict(fixed_noise)
            showX(fake, 4)


[2/30][1100/1875][500] Loss_D: 0.246933 Loss_G: 0.196673 Loss_D_real: 0.164480 Loss_D_fake -0.082453 64.31558012962341
[3/30][1820/1875][1000] Loss_D: 0.154147 Loss_G: 0.168771 Loss_D_real: 0.057514 Loss_D_fake -0.096633 102.50099086761475
[5/30][620/1875][1500] Loss_D: 0.142626 Loss_G: 0.123585 Loss_D_real: 0.121615 Loss_D_fake -0.021011 138.7068555355072
[6/30][1340/1875][2000] Loss_D: 0.147766 Loss_G: 0.113830 Loss_D_real: 0.120745 Loss_D_fake -0.027021 174.4229485988617
[8/30][185/1875][2500] Loss_D: 0.109531 Loss_G: 0.122349 Loss_D_real: 0.061902 Loss_D_fake -0.047629 210.20018982887268
[9/30][905/1875][3000] Loss_D: 0.067302 Loss_G: 0.034699 Loss_D_real: -0.000889 Loss_D_fake -0.068191 246.41217875480652
[10/30][1625/1875][3500] Loss_D: 0.057093 Loss_G: 0.007500 Loss_D_real: 0.019356 Loss_D_fake -0.037737 281.89103627204895
[12/30][470/1875][4000] Loss_D: 0.079177 Loss_G: -0.004505 Loss_D_real: 0.007529 Loss_D_fake -0.071648 318.7459247112274
[13/30][1190/1875][4500] Loss_D: 0.019580 Loss_G: 0.016645 Loss_D_real: 0.002973 Loss_D_fake -0.016608 355.18883538246155
[15/30][35/1875][5000] Loss_D: 0.013220 Loss_G: 0.026995 Loss_D_real: -0.014300 Loss_D_fake -0.027521 393.0301911830902
[16/30][755/1875][5500] Loss_D: -0.000608 Loss_G: 0.004390 Loss_D_real: 0.001323 Loss_D_fake 0.001931 428.9458329677582
[17/30][1475/1875][6000] Loss_D: 0.001484 Loss_G: 0.016446 Loss_D_real: -0.016459 Loss_D_fake -0.017943 464.33614134788513
[19/30][320/1875][6500] Loss_D: 0.002608 Loss_G: 0.041098 Loss_D_real: -0.026738 Loss_D_fake -0.029346 500.28098368644714
[20/30][1040/1875][7000] Loss_D: 0.055256 Loss_G: 0.081144 Loss_D_real: -0.024519 Loss_D_fake -0.079775 535.6275231838226
[21/30][1760/1875][7500] Loss_D: 0.006230 Loss_G: 0.012354 Loss_D_real: -0.013869 Loss_D_fake -0.020099 571.3686983585358
[23/30][605/1875][8000] Loss_D: -0.001072 Loss_G: 0.014652 Loss_D_real: -0.015864 Loss_D_fake -0.014793 607.6018855571747
[24/30][1325/1875][8500] Loss_D: 0.002822 Loss_G: 0.010071 Loss_D_real: 0.000153 Loss_D_fake -0.002669 642.927836894989
[26/30][170/1875][9000] Loss_D: 0.002695 Loss_G: 0.015505 Loss_D_real: -0.006187 Loss_D_fake -0.008882 678.9238963127136
[27/30][890/1875][9500] Loss_D: 0.041745 Loss_G: 0.085019 Loss_D_real: -0.003950 Loss_D_fake -0.045696 714.308536529541
[28/30][1610/1875][10000] Loss_D: -0.004787 Loss_G: 0.012075 Loss_D_real: -0.015661 Loss_D_fake -0.010874 749.6831800937653
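
After training, the weights can be saved for later sampling. A minimal sketch using the standard Keras save_weights/load_weights API (the file names are hypothetical):

netG.save_weights('wgan_netG.h5')    # rebuild with DCGAN_G and call load_weights to reuse
netD.save_weights('wgan_netD.h5')
showX(netG.predict(fixed_noise), 4)  # final samples from the fixed noise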