Test Points

  • Remove the `trainable` setting before compiling
  • Skip compiling the generator
  • Use a default optimizer without changing its learning parameters

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from importlib import reload

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from keras import models, layers, optimizers
from keras.layers import Dense, Input, Conv1D, Reshape, Flatten
from keras.models import Model
from keras.optimizers import Adam


Using TensorFlow backend.

Define Model


In [4]:
lr = 0.0002
# Module-level optimizer kept so other cells that reference `adam` keep working;
# model_compile no longer shares it between models.
adam = Adam(lr=lr, beta_1=0.5)


def model_compile(model, optimizer=None):
    """Compile `model` in place with binary cross-entropy loss and accuracy.

    Parameters
    ----------
    model : keras model
        The model to compile.
    optimizer : keras optimizer, optional
        Optimizer to use. When omitted, a fresh ``Adam(lr=lr, beta_1=0.5)`` is
        created for this model. Reusing one optimizer instance across several
        models (D and the stacked GD here) makes them share optimizer state,
        such as the iteration counter used by Adam's bias correction.
    """
    if optimizer is None:
        optimizer = Adam(lr=lr, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

class GAN:
    """A small GAN over fixed-length vectors.

    Builds three Keras models:
      * D  - MLP discriminator, (Batch, ni_D) -> (Batch, 1) sigmoid score.
      * G  - generator made of kernel-size-1 Conv1D layers, (Batch, ni_D) -> (Batch, ni_D).
      * GD - G followed by D, compiled with D frozen; used to train G.

    Parameters
    ----------
    ni_D : int
        Length of each sample vector (and of the generator's input vector).
    nh_D : int
        Units in each of D's two hidden Dense layers.
    nh_G : int
        Filters in each of G's two hidden Conv1D layers.
    """

    def __init__(self, ni_D, nh_D, nh_G):
        D = models.Sequential()
        D.add(Dense(nh_D, activation='relu', input_shape=(ni_D,)))
        D.add(Dense(nh_D, activation='relu'))
        D.add(Dense(1, activation='sigmoid'))
        model_compile(D)

        G = models.Sequential()  # (Batch, ni_D)
        G.add(Reshape((ni_D, 1), input_shape=(ni_D,)))  # (Batch, steps=ni_D, input_dim=1)
        G.add(Conv1D(nh_G, 1))  # (Batch, ni_D, nh_G)
        G.add(Conv1D(nh_G, 1))  # (Batch, ni_D, nh_G)
        G.add(Conv1D(1, 1))  # (Batch, ni_D, 1)
        G.add(Flatten())  # (Batch, ni_D)
        # G is only trained through GD and otherwise used via predict();
        # compiling it on its own is not required.
        model_compile(G)

        GD = models.Sequential()
        GD.add(G)
        GD.add(D)
        # Freeze D while compiling GD so GD's updates target only G's weights.
        D.trainable = False
        model_compile(GD)
        D.trainable = True

        self.D, self.G, self.GD = D, G, GD

    def D_train_on_batch(self, Real, Gen):
        """Run one discriminator update with Real labelled 1 and Gen labelled 0.

        Parameters
        ----------
        Real, Gen : np.ndarray
            Batches of real and generated samples, concatenated along axis 0.

        Returns
        -------
        The value returned by ``D.train_on_batch`` (loss and compiled metrics).
        """
        D = self.D
        X = np.concatenate([Real, Gen], axis=0)
        # Build labels as an ndarray: a plain Python list of length 1 would be
        # read by Keras as a list of per-output arrays, not as a label vector.
        y = np.concatenate([np.ones(Real.shape[0]), np.zeros(Gen.shape[0])])
        D.trainable = True
        return D.train_on_batch(X, y)

    def GD_train_on_batch(self, Z):
        """Run one generator update through GD, asking D to score G(Z) as real.

        Parameters
        ----------
        Z : np.ndarray
            Batch of generator inputs.

        Returns
        -------
        The value returned by ``GD.train_on_batch`` (loss and compiled metrics).
        """
        GD, D = self.GD, self.D
        y = np.ones(Z.shape[0])
        # Keep D frozen so this step only moves G's weights.
        D.trainable = False
        return GD.train_on_batch(Z, y)

In [5]:
# Construct a stand-alone GAN once; nothing later reuses this instance
# (the Machine class builds its own GAN for training).
gan = GAN(nh_D=50, nh_G=50, ni_D=100)

Load Data


In [6]:
class Data:
    """Random samplers for the GAN experiment.

    Attributes
    ----------
    real_sample : callable(n_batch) -> np.ndarray of shape (n_batch, ni_D)
        Draws "real" data from a normal distribution N(mu, sigma).
    in_sample : callable(n_batch) -> np.ndarray of shape (n_batch, ni_D)
        Draws generator inputs uniformly from [0, 1).
    """

    def __init__(self, mu, sigma, ni_D):
        def draw_real(n_batch):
            return np.random.normal(mu, sigma, (n_batch, ni_D))

        def draw_input(n_batch):
            return np.random.rand(n_batch, ni_D)

        self.real_sample = draw_real
        self.in_sample = draw_input

Test training

  • Test training of the generator G

Define Machine


In [12]:
class Machine:
    """Couples a Data sampler with a GAN and drives training and plotting.

    Parameters
    ----------
    n_batch : int
        Number of real samples, and of generated samples, used per training step.
    ni_D : int
        Length of each sample vector.
    """

    def __init__(self, n_batch=10, ni_D=100):
        self.data = Data(0, 1, ni_D)
        self.gan = GAN(ni_D=ni_D, nh_D=50, nh_G=50)
        self.n_batch = n_batch

    def train_D(self):
        """Run one discriminator step on real vs. generated samples.

        Returns
        -------
        The value returned by ``GAN.D_train_on_batch``.
        """
        gan = self.gan
        n_batch = self.n_batch
        data = self.data

        Real = data.real_sample(n_batch)  # (n_batch, ni_D)
        Z = data.in_sample(n_batch)  # generator input, (n_batch, ni_D)
        Gen = gan.G.predict(Z)  # (n_batch, ni_D)

        gan.D.trainable = True
        return gan.D_train_on_batch(Real, Gen)

    def train_GD(self):
        """Run one generator step through the stacked GD model with D frozen.

        Returns
        -------
        The value returned by ``GAN.GD_train_on_batch``.
        """
        gan = self.gan
        # Seed data for data generation
        Z = self.data.in_sample(self.n_batch)

        gan.D.trainable = False
        return gan.GD_train_on_batch(Z)

    def train_each(self):
        """One full GAN iteration: a D step followed by a G step."""
        self.train_D()
        self.train_GD()

    def train(self, epochs):
        """Repeat ``train_each`` `epochs` times."""
        for _ in range(epochs):
            self.train_each()

    def test(self, n_test):
        """Generate `n_test` samples from the current generator.

        Returns
        -------
        (Gen, Z) : the generated samples and the inputs that produced them.
        """
        Z = self.data.in_sample(n_test)
        Gen = self.gan.G.predict(Z)
        return Gen, Z

    def show_hist(self, Real, Gen, Z):
        """Overlay value histograms of real, generated and input samples."""
        plt.hist(Real.reshape(-1), histtype='step', label='Real')
        plt.hist(Gen.reshape(-1), histtype='step', label='Generated')
        plt.hist(Z.reshape(-1), histtype='step', label='Input')
        plt.xlabel('value')
        plt.ylabel('count')
        plt.legend(loc=0)

    def test_and_show(self, n_test):
        """Generate `n_test` samples and plot them against fresh real samples."""
        Gen, Z = self.test(n_test)
        Real = self.data.real_sample(n_test)
        self.show_hist(Real, Gen, Z)

    def run(self, epochs, n_test):
        """
        train GAN and show the results
        for showing, the original and the artificial results will be compared
        """
        self.train(epochs)
        self.test_and_show(n_test)

    def run_loop(self, n_iter=100, epochs_each=1000, n_test=1000):
        """Run `n_iter` stages of ``run(epochs_each, n_test)``, showing each plot."""
        for ii in range(n_iter):
            print('Stage', ii)
            # Use this instance, not a global `machine` variable.
            self.run(epochs_each, n_test)
            plt.show()

In [13]:
# Each training step uses 10 real and 10 generated samples; each sample is a
# vector of 1000 values.
machine = Machine(ni_D=1000, n_batch=10)

In [14]:
# 100 stages of 1000 training steps each, plotting 1000 test samples per stage.
# NOTE: long-running; the saved run below was interrupted during Stage 7.
machine.run_loop(n_iter=100, epochs_each=1000, n_test=1000)


Stage 0
Stage 1
Stage 2
Stage 3
Stage 4
Stage 5
Stage 6
Stage 7
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-eeb50f7da125> in <module>()
----> 1 machine.run_loop(100, 1000, 1000)

<ipython-input-12-24a986888551> in run_loop(self, n_iter, epochs_each, n_test)
     74         for ii in range(n_iter):
     75             print('Stage', ii)
---> 76             machine.run(epochs_each, n_test)
     77             plt.show()

<ipython-input-12-24a986888551> in run(self, epochs, n_test)
     68         for showing, the original and the artificial results will be compared
     69         """
---> 70         self.train(epochs)
     71         self.test_and_show(n_test)
     72 

<ipython-input-12-24a986888551> in train(self, epochs)
     39     def train(self, epochs):
     40         for epoch in range(epochs):
---> 41             self.train_each()
     42 
     43     def test(self, n_test):

<ipython-input-12-24a986888551> in train_each(self)
     35     def train_each(self):
     36         self.train_D()
---> 37         self.train_GD()
     38 
     39     def train(self, epochs):

<ipython-input-12-24a986888551> in train_GD(self)
     31 
     32         gan.D.trainable = False
---> 33         gan.GD_train_on_batch(Z)
     34 
     35     def train_each(self):

<ipython-input-4-f459f637ed77> in GD_train_on_batch(self, Z)
     38         GD, D = self.GD, self.D
     39         y = [1] * Z.shape[0]
---> 40         GD.train_on_batch(Z, y)

/home/sjkim/anaconda3/lib/python3.6/site-packages/keras/models.py in train_on_batch(self, x, y, class_weight, sample_weight)
    931         return self.model.train_on_batch(x, y,
    932                                          sample_weight=sample_weight,
--> 933                                          class_weight=class_weight)
    934 
    935     def test_on_batch(self, x, y,

/home/sjkim/anaconda3/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1618             ins = x + y + sample_weights
   1619         self._make_train_function()
-> 1620         outputs = self.train_function(ins)
   1621         if len(outputs) == 1:
   1622             return outputs[0]

/home/sjkim/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2071         session = get_session()
   2072         updated = session.run(self.outputs + [self.updates_op],
-> 2073                               feed_dict=feed_dict)
   2074         return updated[:len(self.outputs)]
   2075 

/home/sjkim/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/sjkim/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/home/sjkim/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/sjkim/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1037   def _do_call(self, fn, *args):
   1038     try:
-> 1039       return fn(*args)
   1040     except errors.OpError as e:
   1041       message = compat.as_text(e.message)

/home/sjkim/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1019         return tf_session.TF_Run(session, options,
   1020                                  feed_dict, fetch_list, target_list,
-> 1021                                  status, run_metadata)
   1022 
   1023     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: