Hierarchical GANs for morphological and geometric trees

Imports


In [4]:
import numpy as np

# Keras
from keras.models import Sequential
from keras.layers.core import Dense, Reshape, Dropout, Activation
from keras.layers import Input, merge
from keras.models import Model
from keras.layers.wrappers import TimeDistributed
from keras.layers.recurrent import LSTM

# Other
import matplotlib.pyplot as plt
from copy import deepcopy
import os

%matplotlib inline

# Local
import McNeuron
import models

Example neuron


In [5]:
neuron_list = McNeuron.visualize.get_all_path(os.getcwd()+"/Data/Pyramidal/chen")
neuron = McNeuron.Neuron(file_format='swc', input_file=neuron_list[50])
McNeuron.visualize.plot_2D(neuron)


API demo

Draw a minibatch from generators [without context]


In [6]:
input_dim = 100
batch_size = 32
n_nodes = 10

random_code = np.random.randn(batch_size, 1, input_dim)

geom_model, morph_model = models.generator(use_context=False, n_nodes_out=n_nodes)

locations = geom_model.predict(random_code)
prufer_code = morph_model.predict(random_code)

print locations.shape
print prufer_code.shape


(32, 9, 3)
(32, 8, 10)
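
For n = 10 nodes, the geometry generator returns n - 1 = 9 locations in 3D (presumably because the root/soma sits at a fixed origin), and the morphology generator returns a Prüfer code of length n - 2 = 8, with a softmax over the 10 node labels at each position; a labeled tree on n nodes always has a Prüfer sequence of length n - 2.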

In [7]:
tmp = prufer_code[0, :, :]
np.sum(tmp, axis=1)


Out[7]:
array([ 1.        ,  0.99999994,  1.        ,  1.        ,  1.        ,
        1.        ,  1.        ,  1.        ], dtype=float32)
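
Each row of the generated code is a softmax over the n node labels, which is why the rows sum to 1. To turn a sample into an actual tree, take the argmax at each position and run standard Prüfer decoding. The cell below is a minimal sketch of that decoder; the helper prufer_to_edges is ours, not part of models or McNeuron.


In [ ]:
def prufer_to_edges(seq, n):
    """Standard Prufer decoding: a length n-2 sequence -> the n-1 edges of a labeled tree."""
    degree = [1] * n
    for v in seq:
        degree[v] += 1
    edges = []
    for v in seq:
        # attach the smallest-index remaining leaf to v
        u = min(i for i in range(n) if degree[i] == 1)
        edges.append((u, v))
        degree[u] -= 1
        degree[v] -= 1
    # the two remaining degree-1 nodes form the final edge
    u, w = [i for i in range(n) if degree[i] == 1]
    edges.append((u, w))
    return edges

seq = list(np.argmax(prufer_code[0], axis=1))    # discretize one generated sample
print prufer_to_edges(seq, n_nodes)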

Draw a minibatch from generators [with context]


In [8]:
geom_model_2, morph_model_2 = models.generator(use_context=True, n_nodes_in=10, n_nodes_out=20)
random_code_2 = np.random.randn(batch_size, 1, input_dim)

locations_2 = geom_model_2.predict([locations, prufer_code, random_code_2])
prufer_code_2 = morph_model_2.predict([locations, prufer_code, random_code_2])

print locations_2.shape
print prufer_code_2.shape


(32, 19, 3)
(32, 18, 20)
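
As at the first level, each position of the conditioned code is a softmax over the (now 20) node labels; a quick sanity check:


In [ ]:
print np.allclose(np.sum(prufer_code_2[0], axis=1), 1.0)    # expect True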

Run discriminator on sampled minibatch [$n = 10$ nodes]


In [9]:
disc_model = models.discriminator(n_nodes_in=10)
disc_labels = disc_model.predict([locations, prufer_code])
print disc_labels.shape
print np.squeeze(disc_labels)


(32, 1, 1)
[ 0.19992203  0.14345191  0.10469893  0.21955609  0.22904482  0.28652555
  0.20005496  0.16769955  0.03858241  0.24209709 -0.02403441  0.19252911
  0.30866897  0.34500864  0.29839811  0.12077656  0.29376683  0.3023141
  0.05244704  0.1476939   0.08699895  0.54631227  0.11523922 -0.02631815
  0.45277524  0.18413661 -0.00462817  0.05520289 -0.02865714  0.23677373
  0.14159074  0.01556951]
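
Note that these scores are not probabilities: the discriminator is compiled as a Wasserstein critic (see models.wasserstein_loss in the training loop below), so it outputs unbounded real-valued scores, and negative values are expected.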

Run discriminator on sampled minibatch [$n = 20$ nodes]


In [10]:
disc_model_2 = models.discriminator(n_nodes_in=20)
disc_labels = disc_model_2.predict([locations_2, prufer_code_2])
print disc_labels.shape
print np.squeeze(disc_labels)


(32, 1, 1)
[ 0.23277178  0.43031889 -0.14755693 -0.16495126  0.0477164   0.27254003
 -0.04037338 -0.11756077 -0.0010753  -0.11512145 -0.04064863  0.27470613
  0.07390958 -0.15373152 -0.27612317 -0.02276942  0.24479601 -0.15658778
 -0.29937071 -0.15025449  0.11974847  0.392048    0.13897984 -0.15488473
 -0.04696935 -0.07366657 -0.03407723  0.06505871 -0.0534447   0.33787423
 -0.10418761 -0.23633122]

Run GAN model [without context, $n = 10$ nodes]


In [11]:
gan_model = models.discriminator_on_generators(geom_model,
                                               morph_model,
                                               disc_model,
                                               input_dim=100,
                                               n_nodes_in=10,
                                               use_context=False)
print gan_model.input_shape, gan_model.output_shape


(None, 1, 100) (None, 1, 1)
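
The stacked model maps a noise code directly to a critic score; an end-to-end check with the minibatch drawn earlier:


In [ ]:
print gan_model.predict(random_code).shape    # expect (32, 1, 1)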

Run GAN model [with context, $n = 20$ nodes]


In [12]:
gan_model = models.discriminator_on_generators(geom_model_2,
                                               morph_model_2,
                                               disc_model_2,
                                               input_dim=100,
                                               n_nodes_in=10,
                                               use_context=True)
print gan_model.input_shape, gan_model.output_shape


[(None, 9, 3), (None, 8, 10), (None, 1, 100)] (None, 1, 1)
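
With context, the stack additionally consumes the previous level's sample:


In [ ]:
print gan_model.predict([locations, prufer_code, random_code_2]).shape    # expect (32, 1, 1)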

Training

Useful functions


In [22]:
def clip_weights(model, weight_constraint):
    """
    Clip weights of a keras model to be bounded by given constraints.
    
    Parameters
    ----------
    model: keras model object
        model for which weights need to be clipped
    weight_constraint: list of two floats
        [lower, upper] interval to clip every weight to
    
    Returns
    -------
    model: keras model object
        model with clipped weights
    """
    for l in model.layers:
        weights = l.get_weights()
        weights = [np.clip(w, weight_constraint[0], weight_constraint[1]) for w in weights]
        l.set_weights(weights)
    return model
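
A quick sanity check (using the demo critic disc_model_2 from above): after clipping, every weight should lie inside the given interval:


In [ ]:
clipped = clip_weights(disc_model_2, [-0.01, 0.01])
all_w = [w for l in clipped.layers for w in l.get_weights()]
print max(np.abs(w).max() for w in all_w if w.size)    # should print <= 0.01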

In [24]:
def get_batch():
    """Placeholder: real-data batching is handled by batch_utils.get_batch in the training loop."""
    return None

In [25]:
def gen_batch(batch_size=batch_size,
              n_nodes=n_nodes,
              level=level,
              input_dim=input_dim,
              geom_model=geom_model,
              morph_model=morph_model):
    """
    Generate a batch of samples from 
    geometry and morphology generator 
    networks at desired levels of hierarchy
    
    Parameters
    ----------
    batch_size: int
        batch size
    n_nodes: list of ints
        number of nodes at each level
    level: int
        indicator of level in the hierarchy
    input_dim: int
        dimensionality of noise input
    geom_model: list of keras objects
        geometry generator for each level in the hierarchy
    morph_model: list of keras objects
        morphology generator for each level in the hierarchy
        
    Returns
    -------
    locations: float (batch_size x n_nodes[level] - 1 x 3)
        batch of generated locations
    prufer: float (batch_size x n_nodes[level] - 2 x n_nodes[level])
        batch of generated morphology (softmax over node labels at each position)
    """
    for l in range(0, level + 1):  # walk the hierarchy up to and including the requested level
        
        # Generate noise code
        noise_code = np.random.randn(batch_size, 1, input_dim)
        
        if l == 0:
            # Generate geometry and morphology
            locations = geom_model[l].predict(noise_code)
            prufer = morph_model[l].predict(noise_code)
        else:
            # Assign previous level's geometry and morphology as priors for the next level
            locations_prior = locations
            prufer_prior = prufer
            
            # Generate geometry and morphology conditioned on the previous level
            locations = geom_model[l].predict([locations_prior, prufer_prior, noise_code])
            prufer = morph_model[l].predict([locations_prior, prufer_prior, noise_code])

    return locations, prufer
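
A minimal usage sketch (it assumes the per-level model lists built in the initialization cell below have been created):


In [ ]:
locs, pruf = gen_batch(batch_size=4,
                       n_nodes=n_nodes,
                       level=1,
                       input_dim=input_dim,
                       geom_model=geom_model,
                       morph_model=morph_model)
print locs.shape, pruf.shape    # expect (4, 19, 3) and (4, 18, 20)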

Global parameters


In [19]:
n_levels = 3
n_nodes = [10, 20, 40]

input_dim = 100

n_epochs = 25
batch_size = 64
n_batch_per_epoch = 100
d_iters = 100
lr = 0.00005

weight_constraint = [-0.01, 0.01]
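
These settings follow the usual WGAN recipe: many critic iterations per generator update, critic weights clipped to a small interval, and RMSprop with a small learning rate.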

Initialize model objects at all levels


In [18]:
# ---------------------------------------
# Initialize model objects at all levels
# ---------------------------------------
geom_model = list()
morph_model = list()
disc_model = list()
gan_model = list()

for level in range(n_levels):
    
    # Discriminator
    d_model = models.discriminator(n_nodes_in=n_nodes[level])
    
    # Generators and GANs
    # If we are in the first level, no context
    if level == 0:
        g_model, m_model = models.generator(use_context=False,
                                            n_nodes_out=n_nodes[level])
        gd_model = models.discriminator_on_generators(g_model,
                                                      m_model,
                                                      d_model,
                                                      input_dim=input_dim,
                                                      use_context=False)        
    # In subsequent levels, we need context
    else:
        g_model, m_model = models.generator(use_context=True,
                                            n_nodes_in=n_nodes[level - 1],
                                            n_nodes_out=n_nodes[level])
        gd_model = models.discriminator_on_generators(g_model,
                                                      m_model,
                                                      d_model,
                                                      input_dim=input_dim,
                                                      n_nodes_in=n_nodes[level - 1],
                                                      use_context=True)
        
    # Collect all models into a list
    disc_model.append(d_model)
    geom_model.append(g_model)
    morph_model.append(m_model)
    gan_model.append(gd_model)
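
A quick audit of the stacked models' I/O shapes at every level (assumes the cell above has run):


In [ ]:
for level in range(n_levels):
    print level, gan_model[level].input_shape, gan_model[level].output_shape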

Optimizers


In [21]:
from keras.optimizers import RMSprop
optim = RMSprop(lr=lr)

Loop


In [ ]:
for level in range(n_levels):
    
    # ---------------
    # Compile models
    # ---------------
    g_model = geom_model[level]
    m_model = morph_model[level]
    d_model = disc_model[level]
    gd_model = gan_model[level]
    
    g_model.compile(loss='mse', optimizer=optim)
    m_model.compile(loss='mse', optimizer=optim)
    d_model.trainable = False
    gd_model.compile(loss=models.wasserstein_loss, optimizer=optim)
    d_model.trainable = True
    d_model.compile(loss=models.wasserstein_loss, optimizer=optim)
    
    # -----------------
    # Loop over epochs
    # -----------------
    gen_iterations = 0
    for e in range(n_epochs):

        batch_counter = 1
        while batch_counter < n_batch_per_epoch:

            list_disc_loss_real = list()
            list_disc_loss_gen = list()

            # ----------------------------
            # Step 1: Train discriminator
            # ----------------------------
            for d_iter in range(d_iters):

                # Clip discriminator weights
                d_model = clip_weights(d_model, weight_constraint)

                # Create a batch of real neurons to feed the discriminator
                # (`data` and `batch_utils` are assumed to be loaded/imported elsewhere)
                X_locations_real, X_prufer_real = \
                    batch_utils.get_batch(data=data,
                                          batch_size=batch_size,
                                          batch_counter=batch_counter,
                                          n_nodes=n_nodes[level])

                X_locations_gen, X_prufer_gen = \
                    batch_utils.gen_batch(batch_size=batch_size,
                                          n_nodes=n_nodes,
                                          level=level,
                                          input_dim=input_dim,
                                          geom_model=geom_model,
                                          morph_model=morph_model)

                # Update the discriminator: real samples get label -1, generated +1
                disc_loss_real = \
                    d_model.train_on_batch([X_locations_real, X_prufer_real],
                                           -np.ones(X_locations_real.shape[0]))
                disc_loss_gen = \
                    d_model.train_on_batch([X_locations_gen, X_prufer_gen],
                                           np.ones(X_locations_gen.shape[0]))
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)

            # ------------------------
            # Step 2: Train generator
            # ------------------------
            noise_code = np.random.randn(batch_size, 1, input_dim)

            # Freeze the discriminator
            d_model.trainable = False

            # The stacked GAN takes noise (plus, beyond level 0, the previous
            # level's samples as context) and is trained against label -1
            if level == 0:
                gen_loss = gd_model.train_on_batch(noise_code,
                                                   -np.ones(batch_size))
            else:
                X_locations_prior, X_prufer_prior = \
                    batch_utils.gen_batch(batch_size=batch_size,
                                          n_nodes=n_nodes,
                                          level=level - 1,
                                          input_dim=input_dim,
                                          geom_model=geom_model,
                                          morph_model=morph_model)
                gen_loss = \
                    gd_model.train_on_batch([X_locations_prior, X_prufer_prior, noise_code],
                                            -np.ones(batch_size))

            # Unfreeze the discriminator
            d_model.trainable = True

            # Housekeeping
            gen_iterations += 1
            batch_counter += 1

            # Save images for visualization (say 2 times per epoch)
            # TODO
        
    # Save model weights (every few epochs)
    # TODO

In [ ]:
# Reference image-WGAN training loop, kept as a template for the tree version
# above. The data_utils / models helpers and the settings (dset, img_dim,
# noise_dim, lr_G, lr_D, ...) come from that image codebase.
import time
from keras.utils import generic_utils

# Load and normalize data
X_real_train = data_utils.load_image_dataset(dset, img_dim, image_dim_ordering)

# Get the full real image dimension
img_dim = X_real_train.shape[-3:]

# Create optimizers
opt_G = data_utils.get_optimizer(opt_G, lr_G)
opt_D = data_utils.get_optimizer(opt_D, lr_D)

#######################
# Load models
#######################
noise_dim = (noise_dim,)
if generator == "upsampling":
    generator_model = models.generator_upsampling(noise_dim, img_dim, bn_mode, dset=dset)
else:
    generator_model = models.generator_deconv(noise_dim, img_dim, bn_mode, batch_size, dset=dset)
discriminator_model = models.discriminator(img_dim, bn_mode, dset=dset)
DCGAN_model = models.DCGAN(generator_model, discriminator_model, noise_dim, img_dim)

############################
# Compile models
############################
generator_model.compile(loss='mse', optimizer=opt_G)
discriminator_model.trainable = False
DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
discriminator_model.trainable = True
discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)

# Global iteration counter for generator updates
gen_iterations = 0

#################
# Start training
#################
for e in range(nb_epoch):
    # Initialize progbar and batch counter
    progbar = generic_utils.Progbar(epoch_size)
    batch_counter = 1
    start = time.time()

    while batch_counter < n_batch_per_epoch:

        # Train the critic heavily early on and periodically thereafter
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            disc_iterations = 100
        else:
            disc_iterations = kwargs["disc_iterations"]

        #####################################
        # 1) Train the critic / discriminator
        #####################################
        list_disc_loss_real = []
        list_disc_loss_gen = []
        for disc_it in range(disc_iterations):

            # Clip discriminator weights
            discriminator_model = clip_weights(discriminator_model, weight_constraint)

            # Sample a real batch and a generated batch
            # (the draft left X_disc_gen undefined; it is sampled here via the generator)
            X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))
            X_disc_gen = generator_model.predict(
                data_utils.sample_noise(noise_scale, batch_size, noise_dim))

            # Update the discriminator: real -> -1, generated -> +1
            disc_loss_real = discriminator_model.train_on_batch(X_real_batch,
                                                                -np.ones(X_real_batch.shape[0]))
            disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen,
                                                               np.ones(X_disc_gen.shape[0]))
            list_disc_loss_real.append(disc_loss_real)
            list_disc_loss_gen.append(disc_loss_gen)

        #########################
        # 2) Train the generator
        #########################
        X_gen = data_utils.sample_noise(noise_scale, batch_size, noise_dim)

        # Freeze the discriminator
        discriminator_model.trainable = False
        gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
        # Unfreeze the discriminator
        discriminator_model.trainable = True

        gen_iterations += 1
        batch_counter += 1
        progbar.add(batch_size, values=[("Loss_D", -np.mean(list_disc_loss_real) - np.mean(list_disc_loss_gen)),
                                        ("Loss_D_real", -np.mean(list_disc_loss_real)),
                                        ("Loss_D_gen", np.mean(list_disc_loss_gen)),
                                        ("Loss_G", -gen_loss)])

        # Save images for visualization ~2 times per epoch
        if batch_counter % (n_batch_per_epoch / 2) == 0:
            data_utils.plot_generated_batch(X_real_batch, generator_model,
                                            batch_size, noise_dim, image_dim_ordering)

    print('\nEpoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

    # Save model weights (by default, every 5 epochs)
    data_utils.save_model_weights(generator_model, discriminator_model, DCGAN_model, e)

In [44]:
# data['geometry'][0][20] has shape (3 x 19)
from sklearn.preprocessing import OneHotEncoder
enc = OneHotEncoder(n_values=n_nodes[0])
tmp = enc.fit_transform(np.array([2, 3, 4, 3, 0, 0, 1, 7]).reshape(-1, 1)).toarray()

In [35]:
tmp2 = tmp.T    # transpose: rows are the n_values categories, columns the sequence positions
tmp2


Out[35]:
array([[ 0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

In [ ]:
# Placeholder: extract real Prüfer sequences from `data` here
prufer_code = ()
X_prufer_real = enc.fit_transform(prufer_code).toarray()
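
For real data, each integer Prüfer sequence must become an (n_nodes - 2) x n_nodes one-hot matrix to match the generator and discriminator interfaces. A minimal sketch with random stand-in sequences (the real ones would come from data):


In [ ]:
# Illustrative only: random integers stand in for real Prufer codes
seqs = np.random.randint(0, n_nodes[0], size=(batch_size, n_nodes[0] - 2))
X_prufer_real = np.array([enc.fit_transform(s.reshape(-1, 1)).toarray()
                          for s in seqs])
print X_prufer_real.shape    # expect (64, 8, 10)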