In [1]:
from importlib import reload  # Python 3: reload() lives in importlib, not builtins

import matplotlib.pyplot as plt  # used below (plt.hexbin / plt.show) but was never imported
import numpy as np               # used below (np.percentile) but was never imported

import GAN.models as models
import GAN.cms_datasets as cms
import GAN.plotting as plotting
import GAN.preprocessing as preprocessing
import GAN.base as base


Using TensorFlow backend.

In [2]:
import GAN.utils as utils

# Pick up any edits made to GAN/utils.py without restarting the kernel.
# NOTE(review): `reload` is not a builtin in Python 3 -- this relies on
# `from importlib import reload` having been executed elsewhere.
reload(utils)

class Parameters(utils.Parameters):
    """Hyper-parameters for this conditional WGAN training run."""
        
    # Generator: dense stack with 32/64/128/256 units and L2 weight regularization.
    g_opts=utils.param(dict(name="G_64x5",kernel_sizes=[32,64,128,256],do_weight_reg=1e-2))
    # Critic: 5 dense layers of 256 units each.
    d_opts=utils.param(dict(name="D_256x5",kernel_sizes=[256]*5,
                        clip_weights=2.e-2,activation=None)) # weight clipping and no activation (WGAN critic)
    # Critic (DM) and adversarial (AM) model optimizers share the same settings.
    dm_opts=utils.param(dict(optimizer="RMSprop",opt_kwargs=dict(lr=0.0001)))#, decay=6e-6)))
    am_opts=utils.param(dict(optimizer="RMSprop",opt_kwargs=dict(lr=0.0001)))#, decay=6e-6)))
    
    epochs=utils.param(200)
    batch_size=utils.param(4096)
    plot_every=utils.param(5)  # draw monitoring plots every N epochs
    
    # frac_data=utils.param(10)
    
    loss = "wgan_loss" # use WGAN loss 
    gan_targets = 'gan_targets_hinge' # hinge targets are 1, -1 instead of 0, 1
    schedule = [0]*2+[1] # number of critic iterations per generator iteration
    
    # Directory where checkpoints / CSV logs / tensorboard files are written.
    monitor_dir = utils.param('log')
    
class MyApp(utils.MyApp):
    """Application wrapper exposing Parameters as a configurable class."""
    # Register the Parameters class so its traits can be resolved/configured.
    classes = utils.List([Parameters])

# Resolve the parameter set and inject the values into the notebook's global
# namespace as UPPER_CASE names (G_OPTS, D_OPTS, EPOCHS, LOSS, ...).
# NOTE(review): globals().update() creates hidden state -- every later cell
# depends on these injected names.
notebook_parameters = Parameters(MyApp()).get_params()

globals().update(notebook_parameters)
# Both the critic (DM) and the adversarial (AM) models train with the WGAN loss.
DM_OPTS.update( {"loss":LOSS} )
AM_OPTS.update( {"loss":LOSS} )
# Bare expression: rich-display the resolved parameters for provenance.
notebook_parameters


Out[2]:
{'AM_OPTS': {'loss': 'wgan_loss',
  'opt_kwargs': {'lr': 0.0001},
  'optimizer': 'RMSprop'},
 'BATCH_SIZE': 4096,
 'DM_OPTS': {'loss': 'wgan_loss',
  'opt_kwargs': {'lr': 0.0001},
  'optimizer': 'RMSprop'},
 'D_OPTS': {'activation': None,
  'clip_weights': 0.02,
  'kernel_sizes': [256, 256, 256, 256, 256],
  'name': 'D_256x5'},
 'EPOCHS': 200,
 'GAN_TARGETS': 'gan_targets_hinge',
 'G_OPTS': {'do_weight_reg': 0.0002,
  'kernel_sizes': [32, 64, 128, 256],
  'name': 'G_64x5'},
 'LOSS': 'wgan_loss',
 'PLOT_EVERY': 5,
 'SCHEDULE': [0, 0, 1]}

In [3]:
import GAN.toy_datasets as toys

In [4]:
# Regenerate the toy dataset: 2M samples of the conditional "three peaks"
# distribution -- conditioning (c), real data (x) and noise prior (z),
# each split into train/test. Printed shapes below show (N, 1, 1) arrays.
# NOTE(review): `reload` is not a Python 3 builtin -- needs importlib.
reload(toys)
c_train,c_test,x_train,x_test,z_train,z_test =  toys.three_peaks_conditional_cube(2000000)


(2000000, 1, 1) (2000000, 1, 1) (2000000, 1, 1)
(2000000, 1, 1) (2000000, 1, 1) (2000000, 1, 1)

In [5]:
# Quick look at the conditional structure: real data x vs conditioning c.
# NOTE(review): `plt` (matplotlib.pyplot) is never imported in this notebook;
# this cell relies on it leaking in from elsewhere.
plt.hexbin( c_train.ravel(), x_train.ravel() )
plt.show()
# Compare marginal distributions of the real data and the noise prior.
plotting.plot_hists(x_train.ravel(),z_train.ravel())#,range=[-4,10])



In [6]:
# Per-sample shapes with the leading batch axis dropped; both are (1, 1) here
# (x and z share a shape, c is the conditioning-variable shape).
xz_shape = x_train.shape[1:]
c_shape = c_train.shape[1:]

In [7]:
# Display the per-sample shape as a sanity check.
xz_shape


Out[7]:
(1, 1)

In [ ]:


In [ ]:


In [8]:
reload(models)

# Build the conditional feed-forward GAN: generator maps (c, z) -> x,
# critic scores (c, x) pairs. Options come from the Parameters cell above.
gan = models.MyFFGAN( xz_shape, xz_shape, c_shape=c_shape,
                     g_opts=G_OPTS,
                     d_opts=D_OPTS,
                     dm_opts=DM_OPTS,
                     am_opts=AM_OPTS,
                     gan_targets=GAN_TARGETS
                    )

In [9]:
# Instantiate (and display) the generator network.
gan.get_generator()


(1, 1)
Out[9]:
<keras.engine.training.Model at 0x2b5c7d640400>

In [10]:
# Instantiate the critic; prints the WeightClip 0.02 constraint per layer.
gan.get_discriminator()


WeightClip 0.02
WeightClip 0.02
WeightClip 0.02
WeightClip 0.02
WeightClip 0.02
WeightClip 0.02
Out[10]:
(<keras.engine.training.Model at 0x2b5c7d3b1e48>,
 <keras.engine.training.Model at 0x2b5c7d3b1940>)

In [11]:
# Compile the adversarial model with the WGAN loss and the alternation
# schedule ([0, 0, 1] = two critic updates per generator update).
gan.adversarial_compile(loss=LOSS,schedule=SCHEDULE)


wgan_loss

In [12]:
# Inspect the generator architecture (residual head: output = input + G_64x5_output).
gan.get_generator().summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
G_64x5_c_input (InputLayer)      (None, 1, 1)          0                                            
____________________________________________________________________________________________________
G_64x5_input (InputLayer)        (None, 1, 1)          0                                            
____________________________________________________________________________________________________
G_64x5_all_inputs (Concatenate)  (None, 1, 2)          0           G_64x5_c_input[0][0]             
                                                                   G_64x5_input[0][0]               
____________________________________________________________________________________________________
G_64x5_up1_dense (Dense)         (None, 1, 256)        768         G_64x5_all_inputs[0][0]          
____________________________________________________________________________________________________
G_64x5_up1_activ (PReLU)         (None, 1, 256)        256         G_64x5_up1_dense[0][0]           
____________________________________________________________________________________________________
G_64x5_up2_dense (Dense)         (None, 1, 128)        32896       G_64x5_up1_activ[0][0]           
____________________________________________________________________________________________________
G_64x5_up2_activ (PReLU)         (None, 1, 128)        128         G_64x5_up2_dense[0][0]           
____________________________________________________________________________________________________
G_64x5_up3_dense (Dense)         (None, 1, 64)         8256        G_64x5_up2_activ[0][0]           
____________________________________________________________________________________________________
G_64x5_up3_activ (PReLU)         (None, 1, 64)         64          G_64x5_up3_dense[0][0]           
____________________________________________________________________________________________________
G_64x5_up4_dense (Dense)         (None, 1, 32)         2080        G_64x5_up3_activ[0][0]           
____________________________________________________________________________________________________
G_64x5_up4_activ (PReLU)         (None, 1, 32)         32          G_64x5_up4_dense[0][0]           
____________________________________________________________________________________________________
G_64x5_output (Dense)            (None, 1, 1)          33          G_64x5_up4_activ[0][0]           
____________________________________________________________________________________________________
G_64x5_add (Add)                 (None, 1, 1)          0           G_64x5_input[0][0]               
                                                                   G_64x5_output[0][0]              
====================================================================================================
Total params: 44,513
Trainable params: 44,513
Non-trainable params: 0
____________________________________________________________________________________________________

In [13]:
# gan.get_discriminator().summary()

In [14]:
# gan.am.summary()

In [15]:
# gan.dm.summary()

In [16]:
# gan.gan.summary()

In [17]:
# gan.gan.outputs, gan.gan.inputs

In [18]:
reload(base)

# Resume bookkeeping: if this GAN object was already (partially) trained in
# this kernel session, continue from the last completed epoch.
# Fix: also require a non-empty epoch list -- after an interrupted fit (as
# happens below via KeyboardInterrupt) the History object can exist with no
# completed epochs, and `epoch[-1]` on an empty list raises IndexError.
initial_epoch = 0
if hasattr(gan.model, "history") and gan.model.history.epoch:
    initial_epoch = gan.model.history.epoch[-1]

# Keyword arguments for the adversarial fit: real data (x), noise (z), the
# shared conditioning variable (c) for both, plus the test split for plots.
do = dict(
    x_train=x_train,
    z_train=z_train,
    c_x_train=c_train,
    c_z_train=c_train,

    x_test=x_test,
    z_test=z_test,
    c_x_test=c_test,
    c_z_test=c_test,

    # NOTE(review): the +1 makes Keras run EPOCHS+1 epochs on a fresh start
    # (output shows "Epoch 1/201" for EPOCHS=200) -- confirm intentional.
    n_epochs=EPOCHS + initial_epoch +1,
    initial_epoch=initial_epoch,
    batch_size=BATCH_SIZE,
    plot_every=PLOT_EVERY,

    monitor_dir=MONITOR_DIR
)

# Call the unbound method explicitly so the reload()ed base module's code is
# used even though `gan` was instantiated from the pre-reload class.
base.MyGAN.adversarial_fit(gan,**do)


/users/musella/jupyter/GAN/GAN/base.py:179: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.
  callbacks = [checkpoint,csv,tensorboard,plotter], **kwargs
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
calling WeightClip 0.02
Epoch 1/201
1499136/1500000 [============================>.] - ETA: 0s - loss: 0.1071 - discriminator_loss: 0.0493 - discriminator_model_2_loss: 0.0271 - discriminator_D_256x5_output_loss: -0.0313 - generator_loss: 0.0578 - generator_model_3_loss: -0.0271 - generator_D_256x5_output_loss: 0.0313           [-2.1707997829620935, 2.0175459014303696]
[-1.414192372560501, 2.0102448123566781]
[-0.95023128986358651, 1.9995040033785139]
[-0.91602052166432557, 2.2017407178878763]
[-0.91650795491705539, 2.4873297572135913]
[-0.92755347325415416, 2.6421164274215698]
[-0.92191305341739294, 2.7325156927108756]
[-0.12146046012639999, 0.19283143132925021]
1500000/1500000 [==============================] - 68s - loss: 0.1071 - discriminator_loss: 0.0493 - discriminator_model_2_loss: 0.0271 - discriminator_D_256x5_output_loss: -0.0313 - generator_loss: 0.0578 - generator_model_3_loss: -0.0271 - generator_D_256x5_output_loss: 0.0313    
Epoch 2/201
1500000/1500000 [==============================] - 18s - loss: 0.0919 - discriminator_loss: 0.0329 - discriminator_model_2_loss: 0.1340 - discriminator_D_256x5_output_loss: -0.1471 - generator_loss: 0.0590 - generator_model_3_loss: -0.1340 - generator_D_256x5_output_loss: 0.1471      ETA: 17s - loss: 0.0983 - discriminator_loss: 0.0451 - discriminator_model_2_loss: 0.0611 - discrim
Epoch 3/201
1500000/1500000 [==============================] - 18s - loss: 0.0817 - discriminator_loss: 0.0309 - discriminator_model_2_loss: 0.3779 - discriminator_D_256x5_output_loss: -0.3878 - generator_loss: 0.0508 - generator_model_3_loss: -0.3779 - generator_D_256x5_output_loss: 0.3878    - ETA: 7s - loss: 0.0835 - discriminator_loss: 0.0297 - discriminator_model_2_loss: 0.3708 - discriminator_D_
Epoch 4/201
1500000/1500000 [==============================] - 18s - loss: 0.0734 - discriminator_loss: 0.0289 - discriminator_model_2_loss: 0.3931 - discriminator_D_256x5_output_loss: -0.4008 - generator_loss: 0.0444 - generator_model_3_loss: -0.3931 - generator_D_256x5_output_loss: 0.4008    
Epoch 5/201
1500000/1500000 [==============================] - 19s - loss: 0.0670 - discriminator_loss: 0.0290 - discriminator_model_2_loss: 0.2754 - discriminator_D_256x5_output_loss: -0.2798 - generator_loss: 0.0379 - generator_model_3_loss: -0.2754 - generator_D_256x5_output_loss: 0.2798    
Epoch 6/201
1495040/1500000 [============================>.] - ETA: 0s - loss: 0.0618 - discriminator_loss: 0.0274 - discriminator_model_2_loss: 0.2274 - discriminator_D_256x5_output_loss: -0.2309 - generator_loss: 0.0344 - generator_model_3_loss: -0.2274 - generator_D_256x5_output_loss: 0.2309[-2.2633331060409549, 2.0175459014303696]
[-1.4161050796508792, 2.0102448123566781]
[-0.93014208780302732, 2.0282906174659674]
[-0.91602052166432557, 2.0732582211494406]
[-0.91650795491705539, 2.0927091002464269]
[-0.92755347325415416, 2.2680769681930557]
[-0.92191305341739294, 2.573496944727184]
[-0.22567487359046934, 0.15353203564882273]
1500000/1500000 [==============================] - 68s - loss: 0.0618 - discriminator_loss: 0.0274 - discriminator_model_2_loss: 0.2272 - discriminator_D_256x5_output_loss: -0.2307 - generator_loss: 0.0344 - generator_model_3_loss: -0.2272 - generator_D_256x5_output_loss: 0.2307    
Epoch 7/201
1500000/1500000 [==============================] - 19s - loss: 0.0571 - discriminator_loss: 0.0244 - discriminator_model_2_loss: 0.0987 - discriminator_D_256x5_output_loss: -0.1028 - generator_loss: 0.0327 - generator_model_3_loss: -0.0987 - generator_D_256x5_output_loss: 0.1028    
Epoch 8/201
1500000/1500000 [==============================] - 18s - loss: 0.0535 - discriminator_loss: 0.0225 - discriminator_model_2_loss: 0.0700 - discriminator_D_256x5_output_loss: -0.0743 - generator_loss: 0.0310 - generator_model_3_loss: -0.0700 - generator_D_256x5_output_loss: 0.0743    
Epoch 9/201
1500000/1500000 [==============================] - 18s - loss: 0.0504 - discriminator_loss: 0.0206 - discriminator_model_2_loss: 0.0469 - discriminator_D_256x5_output_loss: -0.0515 - generator_loss: 0.0298 - generator_model_3_loss: -0.0469 - generator_D_256x5_output_loss: 0.0515    
Epoch 10/201
1500000/1500000 [==============================] - 18s - loss: 0.0480 - discriminator_loss: 0.0192 - discriminator_model_2_loss: 0.0187 - discriminator_D_256x5_output_loss: -0.0235 - generator_loss: 0.0288 - generator_model_3_loss: -0.0187 - generator_D_256x5_output_loss: 0.0235    
Epoch 11/201
1499136/1500000 [============================>.] - ETA: 0s - loss: 0.0460 - discriminator_loss: 0.0179 - discriminator_model_2_loss: 0.0011 - discriminator_D_256x5_output_loss: -0.0062 - generator_loss: 0.0282 - generator_model_3_loss: -0.0011 - generator_D_256x5_output_loss: 0.0062        [-2.2086803197860712, 2.0175459014303696]
[-1.3325912714004517, 2.0102448123566781]
[-0.93014208780302732, 1.9995040033785139]
[-0.91602052166432557, 2.0682575464248645]
[-0.91650795491705539, 2.1943865060806269]
[-0.92755347325415416, 2.3464674711227422]
[-0.92191305341739294, 2.573496944727184]
[-0.1023152686655521, 0.14244886413216587]
1500000/1500000 [==============================] - 67s - loss: 0.0460 - discriminator_loss: 0.0179 - discriminator_model_2_loss: 0.0011 - discriminator_D_256x5_output_loss: -0.0062 - generator_loss: 0.0282 - generator_model_3_loss: -0.0011 - generator_D_256x5_output_loss: 0.0062    
Epoch 12/201
1500000/1500000 [==============================] - 18s - loss: 0.0442 - discriminator_loss: 0.0168 - discriminator_model_2_loss: -0.0225 - discriminator_D_256x5_output_loss: 0.0173 - generator_loss: 0.0274 - generator_model_3_loss: 0.0225 - generator_D_256x5_output_loss: -0.0173    
Epoch 13/201
1500000/1500000 [==============================] - 18s - loss: 0.0427 - discriminator_loss: 0.0162 - discriminator_model_2_loss: -0.0217 - discriminator_D_256x5_output_loss: 0.0165 - generator_loss: 0.0266 - generator_model_3_loss: 0.0217 - generator_D_256x5_output_loss: -0.0165    
Epoch 14/201
1500000/1500000 [==============================] - 18s - loss: 0.0415 - discriminator_loss: 0.0157 - discriminator_model_2_loss: -0.0227 - discriminator_D_256x5_output_loss: 0.0177 - generator_loss: 0.0258 - generator_model_3_loss: 0.0227 - generator_D_256x5_output_loss: -0.0177    
Epoch 15/201
1500000/1500000 [==============================] - 19s - loss: 0.0405 - discriminator_loss: 0.0157 - discriminator_model_2_loss: -0.0339 - discriminator_D_256x5_output_loss: 0.0294 - generator_loss: 0.0248 - generator_model_3_loss: 0.0339 - generator_D_256x5_output_loss: -0.0294    
Epoch 16/201
1499136/1500000 [============================>.] - ETA: 0s - loss: 0.0397 - discriminator_loss: 0.0158 - discriminator_model_2_loss: -0.0176 - discriminator_D_256x5_output_loss: 0.0136 - generator_loss: 0.0238 - generator_model_3_loss: 0.0176 - generator_D_256x5_output_loss: -0.0136[-2.1707997829620935, 2.0175459014303696]
[-1.2633458673954012, 2.0102448123566781]
[-0.93014208780302732, 1.9995040033785139]
[-0.91602052166432557, 2.0152273994879035]
[-0.91650795491705539, 2.0307035326957674]
[-0.92755347325415416, 2.2584376811981199]
[-0.92191305341739294, 2.573496944727184]
[0.026243185205385093, 0.2353737264871596]
1500000/1500000 [==============================] - 67s - loss: 0.0397 - discriminator_loss: 0.0158 - discriminator_model_2_loss: -0.0177 - discriminator_D_256x5_output_loss: 0.0137 - generator_loss: 0.0238 - generator_model_3_loss: 0.0177 - generator_D_256x5_output_loss: -0.0137    
Epoch 17/201
1500000/1500000 [==============================] - 18s - loss: 0.0389 - discriminator_loss: 0.0161 - discriminator_model_2_loss: -0.0140 - discriminator_D_256x5_output_loss: 0.0106 - generator_loss: 0.0228 - generator_model_3_loss: 0.0140 - generator_D_256x5_output_loss: -0.0106    
Epoch 18/201
 688128/1500000 [============>.................] - ETA: 10s - loss: 0.0384 - discriminator_loss: 0.0164 - discriminator_model_2_loss: -0.0044 - discriminator_D_256x5_output_loss: 0.0015 - generator_loss: 0.0221 - generator_model_3_loss: 0.0044 - generator_D_256x5_output_loss: -0.0015
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-2f4294fcbe9f> in <module>()
     22 )
     23 
---> 24 base.MyGAN.adversarial_fit(gan,**do)

~/jupyter/GAN/GAN/base.py in adversarial_fit(self, x_train, z_train, c_x_train, c_z_train, w_x_train, w_z_train, x_test, z_test, c_x_test, c_z_test, w_x_test, w_z_test, batch_size, n_epochs, plot_every, monitor_dir, checkpoint_every, **kwargs)
    177         self.model.fit( train_x, train_y,  sample_weight=train_w,
    178                         nb_epoch=n_epochs, batch_size=batch_size,
--> 179                         callbacks = [checkpoint,csv,tensorboard,plotter], **kwargs
    180         )
    181 

~/my-env/lib/python3.5/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1596                               initial_epoch=initial_epoch,
   1597                               steps_per_epoch=steps_per_epoch,
-> 1598                               validation_steps=validation_steps)
   1599 
   1600     def evaluate(self, x, y,

~/my-env/lib/python3.5/site-packages/keras/engine/training.py in _fit_loop(self, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
   1181                     batch_logs['size'] = len(batch_ids)
   1182                     callbacks.on_batch_begin(batch_index, batch_logs)
-> 1183                     outs = f(ins_batch)
   1184                     if not isinstance(outs, list):
   1185                         outs = [outs]

~/my-env/lib/python3.5/site-packages/keras_adversarial-0.0.3-py3.5.egg/keras_adversarial/adversarial_optimizers.py in train(_inputs)
    105                 self.iter = 0
    106             func = funcs[self.schedule[self.iter]]
--> 107             return func(_inputs)
    108 
    109         return train

~/my-env/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2271         updated = session.run(self.outputs + [self.updates_op],
   2272                               feed_dict=feed_dict,
-> 2273                               **self.session_kwargs)
   2274         return updated[:len(self.outputs)]
   2275 

/apps/dom/UES/jenkins/6.0.UP04/gpu/easybuild/software/TensorFlow/1.2.1-CrayGNU-17.08-cuda-8.0-python3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    787     try:
    788       result = self._run(None, fetches, feed_dict, options_ptr,
--> 789                          run_metadata_ptr)
    790       if run_metadata:
    791         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/apps/dom/UES/jenkins/6.0.UP04/gpu/easybuild/software/TensorFlow/1.2.1-CrayGNU-17.08-cuda-8.0-python3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    995     if final_fetches or final_targets:
    996       results = self._do_run(handle, final_targets, final_fetches,
--> 997                              feed_dict_string, options, run_metadata)
    998     else:
    999       results = []

/apps/dom/UES/jenkins/6.0.UP04/gpu/easybuild/software/TensorFlow/1.2.1-CrayGNU-17.08-cuda-8.0-python3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1130     if handle is None:
   1131       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1132                            target_list, options, run_metadata)
   1133     else:
   1134       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/apps/dom/UES/jenkins/6.0.UP04/gpu/easybuild/software/TensorFlow/1.2.1-CrayGNU-17.08-cuda-8.0-python3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1137   def _do_call(self, fn, *args):
   1138     try:
-> 1139       return fn(*args)
   1140     except errors.OpError as e:
   1141       message = compat.as_text(e.message)

/apps/dom/UES/jenkins/6.0.UP04/gpu/easybuild/software/TensorFlow/1.2.1-CrayGNU-17.08-cuda-8.0-python3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1119         return tf_session.TF_Run(session, options,
   1120                                  feed_dict, fetch_list, target_list,
-> 1121                                  status, run_metadata)
   1122 
   1123     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]:
# Map noise through the generator; element [1] of the returned list is taken
# as the generated (morphed) sample. TODO(review): confirm the output
# ordering of the generator's predict() against GAN/models.py.
x_morphed = gan.get_generator().predict([c_test,z_test])[1]

# Critic scores for real and generated samples under the same conditioning.
x_p = gan.get_discriminator()[1].predict([c_test,x_test])
z_p   = gan.get_discriminator()[1].predict([c_test,x_morphed])

In [ ]:
reload(plotting)

# Conditioning-variable bin edges (quantiles) for the sliced comparison plots.
# NOTE(review): `np` (numpy) is never imported in this notebook -- relies on
# it leaking in from elsewhere.
quantiles = np.percentile(c_test,[0,5,20,40,60,80,95,100])

# Compare real vs generated distributions, marginally and per c-slice.
plotting.plot_summary_cond(x_test,c_test,x_morphed,c_test,z_test,x_p,z_p,
                           do_slices=False,c_bounds=quantiles)
plotting.plot_summary_cond(x_test,c_test,x_morphed,c_test,z_test,x_p,z_p,
                           do_slices=True,c_bounds=quantiles)

In [ ]:
# Inspect the compiled training function (debugging the adversarial optimizer).
gan.model.train_function

In [ ]:


In [ ]: