Hierarchical GANs for morphological and geometric trees

Imports


In [1]:
import numpy as np

# Keras
from keras.models import Sequential
from keras.layers.core import Dense, Reshape, Dropout, Activation
from keras.layers import Input, merge
from keras.models import Model
from keras.layers.wrappers import TimeDistributed
from keras.layers.recurrent import LSTM

# Other
import matplotlib.pyplot as plt
from copy import deepcopy
import os
import pickle

%matplotlib inline

# Local
import McNeuron
import models_generate_parents as models
import train_one_by_one as train
import batch_utils
import data_transforms


Using Theano backend.

Example neuron


In [2]:
# Load all reconstructions under Data/Pyramidal/chen and plot one example.
neuron_list = McNeuron.visualize.get_all_path(os.getcwd() + "/Data/Pyramidal/chen")
neuron = McNeuron.Neuron(file_format='swc', input_file=neuron_list[50])
McNeuron.visualize.plot_2D(neuron)


/Users/pavanramkumar/anaconda2/lib/python2.7/site-packages/scipy/sparse/compressed.py:730: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.
  SparseEfficiencyWarning)
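
The Neuron object exposes the tree topology directly; a quick sanity check (a minimal sketch, relying only on the parent_index attribute used later in this notebook):

In [ ]:
# parent_index[i] is the index of node i's parent, so its length equals
# the number of nodes in the reconstruction.
print len(neuron.parent_index)
print neuron.parent_index[:10]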

Training

Load Data


In [3]:
# The pickle holds a nested dict: training_data['morphology'] and
# training_data['geometry'], each keyed by tree size (here 'n20').
training_data = pickle.load(open("/Users/pavanramkumar/Dropbox/HG-GAN/03-Data/synthetic_Two_segment_model_Different_directions_Parent_distance_n20_parent_id.p", "rb"))
#training_data = pickle.load(open("/Users/pavanramkumar/Dropbox/HG-GAN/03-Data/train4.p", "rb"))

In [4]:
print training_data['morphology']['n20'].shape
print training_data['geometry']['n20'].shape


(50000, 19)
(50000, 19, 3)
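
Each morphology row holds one parent ID per non-root node (per the file's parent_id suffix), and each geometry slice holds the matching 3D node locations; a quick look at one sample:

In [ ]:
# One training example: 19 parent IDs and the matching 19 x 3 locations.
sample = 0
print training_data['morphology']['n20'][sample, :]
print training_data['geometry']['n20'][sample, :, :]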

This cell converts Prufer-encoded morphologies into parent-index vectors (decode_prufer returns the full parent list, so a[1:] drops the root's entry). It fails here because this pickle ships only 'n20', which is already stored as parent IDs.

In [5]:
import data_transforms
v = np.zeros([training_data['morphology']['n40'].shape[0],39])
for i in range(training_data['morphology']['n40'].shape[0]):
    a = data_transforms.decode_prufer(list(training_data['morphology']['n40'][i,:]))
    a = np.array(a)
    v[i,:] = a[1:]
training_data['morphology']['n40'] = v


---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-5-ef393240b6f6> in <module>()
      1 import data_transforms
----> 2 v = np.zeros([training_data['morphology']['n40'].shape[0],39])
      3 for i in range(training_data['morphology']['n40'].shape[0]):
      4     a = data_transforms.decode_prufer(list(training_data['morphology']['n40'][i,:]))
      5     a = np.array(a)

KeyError: 'n40'
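
A guarded version of the same conversion, run only when a Prufer-encoded key is actually present (a sketch; there is nothing to convert in this particular pickle):

In [ ]:
# Convert Prufer sequences to parent-index vectors for any tree size
# that exists in the pickle; skip keys that are absent.
key = 'n40'
if key in training_data['morphology']:
    seqs = training_data['morphology'][key]
    n = int(key[1:])                        # number of nodes, e.g. 40
    v = np.zeros([seqs.shape[0], n - 1])
    for i in range(seqs.shape[0]):
        parents = np.array(data_transforms.decode_prufer(list(seqs[i, :])))
        v[i, :] = parents[1:]               # drop the root's entry
    training_data['morphology'][key] = v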

Global parameters


In [5]:
n_levels = 1                       # number of levels in the hierarchy
n_nodes = [20]                     # tree size at each level

input_dim = 100                    # dimensionality of the latent noise code

n_epochs = 5
batch_size = 32
n_batch_per_epoch = np.floor(training_data['morphology']['n20'].shape[0] / batch_size).astype(int)
d_iters = 100                      # critic updates per generator update
lr_discriminator = 0.001
lr_generator = 0.001
train_loss = 'wasserstein_loss'

rule = 'none'                      # generation order: 'gmd', 'mgd', or 'none'
train_one_by_one = False
weight_constraint = [-0.03, 0.03]  # WGAN weight-clipping range
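
The train_loss string names the usual WGAN critic objective; a minimal sketch of that loss in the conventional Keras formulation (an assumption; train_one_by_one.py may define it differently):

In [ ]:
from keras import backend as K

def wasserstein_loss(y_true, y_pred):
    # With labels +1 for real and -1 for generated samples, minimizing
    # this pushes the critic scores for the two populations apart;
    # weight clipping (weight_constraint above) keeps the critic
    # approximately Lipschitz.
    return K.mean(y_true * y_pred)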

Run

The three model summaries printed below describe, in order, the discriminator (critic), the conditional geometry generator, and the morphology generator.


In [6]:
geom_model, cond_geom_model, morph_model, cond_morph_model, disc_model, gan_model = \
    train.train_model(training_data=training_data,
                      n_levels=n_levels,
                      n_nodes=n_nodes,
                      input_dim=input_dim,
                      n_epochs=n_epochs,
                      batch_size=batch_size,
                      n_batch_per_epoch=n_batch_per_epoch,
                      d_iters=d_iters,
                      lr_discriminator=lr_discriminator,
                      lr_generator=lr_generator,
                      weight_constraint=weight_constraint,
                      rule=rule,
                      train_one_by_one=train_one_by_one,
                      train_loss=train_loss,
                      verbose=True)


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 19, 3)         0                                            
____________________________________________________________________________________________________
input_2 (InputLayer)             (None, 19, 20)        0                                            
____________________________________________________________________________________________________
merge_1 (Merge)                  (None, 19, 23)        0           input_1[0][0]                    
                                                                   input_2[0][0]                    
____________________________________________________________________________________________________
lstm_1 (LSTM)                    (None, 19, 20)        3520        merge_1[0][0]                    
____________________________________________________________________________________________________
reshape_1 (Reshape)              (None, 1, 380)        0           lstm_1[0][0]                     
____________________________________________________________________________________________________
embedding (Dense)                (None, 1, 100)        38100       reshape_1[0][0]                  
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 1, 50)         5050        embedding[0][0]                  
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 1, 50)         2550        dense_1[0][0]                    
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 1, 50)         2550        dense_2[0][0]                    
____________________________________________________________________________________________________
dense_4 (Dense)                  (None, 1, 1)          51          dense_3[0][0]                    
====================================================================================================
Total params: 51,821
Trainable params: 51,821
Non-trainable params: 0
____________________________________________________________________________________________________
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_4 (InputLayer)             (None, 19, 20)        0                                            
____________________________________________________________________________________________________
lstm_3 (LSTM)                    (None, 19, 20)        3280        input_4[0][0]                    
____________________________________________________________________________________________________
reshape_3 (Reshape)              (None, 1, 380)        0           lstm_3[0][0]                     
____________________________________________________________________________________________________
noise_input (InputLayer)         (None, 1, 100)        0                                            
____________________________________________________________________________________________________
morphology_embedding (Dense)     (None, 1, 100)        38100       reshape_3[0][0]                  
____________________________________________________________________________________________________
merge_2 (Merge)                  (None, 1, 100)        0           noise_input[0][0]                
                                                                   morphology_embedding[0][0]       
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 1, 57)         5757        merge_2[0][0]                    
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 1, 57)         3306        dense_7[0][0]                    
____________________________________________________________________________________________________
reshape_5 (Reshape)              (None, 19, 3)         0           dense_8[0][0]                    
====================================================================================================
Total params: 50,443
Trainable params: 50,443
Non-trainable params: 0
____________________________________________________________________________________________________
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
noise_input (InputLayer)         (None, 1, 100)        0                                            
____________________________________________________________________________________________________
dense_9 (Dense)                  (None, 1, 380)        38380       noise_input[0][0]                
____________________________________________________________________________________________________
dense_10 (Dense)                 (None, 1, 380)        144780      dense_9[0][0]                    
____________________________________________________________________________________________________
dense_11 (Dense)                 (None, 1, 380)        144780      dense_10[0][0]                   
____________________________________________________________________________________________________
reshape_6 (Reshape)              (None, 19, 20)        0           dense_11[0][0]                   
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 19, 20)        0           reshape_6[0][0]                  
====================================================================================================
Total params: 327,940
Trainable params: 327,940
Non-trainable params: 0
____________________________________________________________________________________________________

====================
Level #0
====================

    Epoch #0

    After 100 iterations
        Discriminator Loss                         = -14.2382354736

    Generator_Loss: -3.69740653038
2
    After 100 iterations
        Discriminator Loss                         = -13.7442436218

    Generator_Loss: -1.89559566975
3

    [... batches 3-99 omitted; the discriminator loss drifts from about
    -15 toward -9 while the generator loss fluctuates between roughly
    -16 and +3 ...]

     Level #0 Epoch #0 Batch #100
    After 100 iterations
        Discriminator Loss                         = -8.94493865967

    Generator_Loss: -2.69774723053
101
    After 100 iterations
        Discriminator Loss                         = -9.83259487152

    Generator_Loss: -6.65986251831
102
    After 100 iterations
        Discriminator Loss                         = -8.54524326324

    Generator_Loss: -6.7202129364
103
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-6-b84be581c68b> in <module>()
     13                       train_one_by_one=train_one_by_one,
     14                       train_loss=train_loss,
---> 15                       verbose=True)

/Users/pavanramkumar/Projects/34-HGGAN/McNeuron/train_one_by_one.py in train_model(training_data, n_levels, n_nodes, input_dim, n_epochs, batch_size, n_batch_per_epoch, d_iters, lr_discriminator, lr_generator, weight_constraint, rule, train_one_by_one, train_loss, verbose)
    353                         d_model.train_on_batch([X_locations,
    354                                                 X_prufer],
--> 355                                                y)
    356 
    357                     list_d_loss.append(disc_loss)

/Users/pavanramkumar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in train_on_batch(self, x, y, sample_weight, class_weight)
   1308             sample_weight=sample_weight,
   1309             class_weight=class_weight,
-> 1310             check_batch_axis=True)
   1311         if self.uses_learning_phase and not isinstance(K.learning_phase, int):
   1312             ins = x + y + sample_weights + [1.]

/Users/pavanramkumar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in _standardize_user_data(self, x, y, sample_weight, class_weight, check_batch_axis, batch_size)
   1042         check_array_lengths(x, y, sample_weights)
   1043         check_loss_and_target_compatibility(y, self.loss_functions, self.internal_output_shapes)
-> 1044         if self.stateful and batch_size:
   1045             if x[0].shape[0] % batch_size != 0:
   1046                 raise ValueError('In a stateful network, '

/Users/pavanramkumar/anaconda2/lib/python2.7/site-packages/keras/engine/topology.pyc in stateful(self)
   2110     @property
   2111     def stateful(self):
-> 2112         return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
   2113 
   2114     def reset_states(self):

KeyboardInterrupt: 
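
Between critic updates, WGAN training clips every critic weight into weight_constraint; a self-contained helper sketch (hypothetical name, not part of train_one_by_one):

In [ ]:
def clip_weights(model, lo, hi):
    # Clip every weight of `model` into [lo, hi], enforcing the
    # Wasserstein-GAN Lipschitz constraint after each critic update.
    for layer in model.layers:
        layer.set_weights([np.clip(w, lo, hi) for w in layer.get_weights()])

# e.g. clip_weights(d_model, weight_constraint[0], weight_constraint[1])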

Because the training run above was interrupted, train_model never returned and the names below are undefined; re-run the training cell to completion before executing the rest of the notebook.

In [8]:
level = 0
g_model = geom_model[level]           # geometry generator
m_model = morph_model[level]          # morphology generator
cg_model = cond_geom_model[level]     # conditional geometry generator
cm_model = cond_morph_model[level]    # conditional morphology generator
d_model = disc_model[level]           # discriminator (critic)
stacked_model = gan_model[level]      # stacked generator + critic


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-8-35bb7e0460d5> in <module>()
      1 level = 0
----> 2 g_model = geom_model[level]           # geometry generator
      3 m_model = morph_model[level]          # morphology generator
      4 cg_model = cond_geom_model[level]     # conditional geometry generator
      5 cm_model = cond_morph_model[level]    # conditional morphology generator

NameError: name 'geom_model' is not defined

In [ ]:
# Freeze the critic so only the generator trains through the stacked model.
d_model.trainable = False
stacked_model.get_config()

In [ ]:
stacked_model.summary()

In [ ]:
# Sample a latent code and generate one tree.  The generation order
# depends on `rule`: 'gmd' samples geometry first, then morphology
# conditioned on it; 'mgd' does the reverse.  With rule = 'none' (set
# above) neither branch runs, so set rule accordingly first.
noise_code = np.random.randn(1, 1, 100)

if rule == 'gmd':
    locations_gen = geom_model[0].predict(noise_code)
    softmax_gen = np.squeeze(cond_morph_model[0].predict([noise_code, locations_gen]))
elif rule == 'mgd':
    softmax_gen = morph_model[0].predict(noise_code)
    locations_gen = cond_geom_model[0].predict([noise_code, softmax_gen])
    softmax_gen = np.squeeze(softmax_gen)

In [ ]:
# Each row of softmax_gen is one node's distribution over candidate parents.
plt.imshow(softmax_gen, interpolation='none', cmap='Greys')
plt.colorbar()
plt.show()

In [ ]:
# Rebuild a Neuron object from the generated locations and parent softmax.
neuron_object = train.plot_example_neuron(locations_gen, softmax_gen)
neuron_object = McNeuron.Neuron(file_format='only list of nodes', input_file=neuron_object.nodes_list)

In [ ]:
# Most-likely parent index for each node.
plt.plot(softmax_gen.argmax(axis=1))
plt.ylim([0, 20])
plt.show()

In [ ]:
neuron_object.parent_index

In [ ]:
McNeuron.visualize.plot_dedrite_tree(neuron_object)

In [ ]:
# Compare the range of generated locations against a training example.
print np.max(training_data['geometry']['n20'][0, :, :]), np.min(training_data['geometry']['n20'][0, :, :])
print locations_gen.max(), locations_gen.min()

In [ ]:
# Extract morphometric features from the generated neuron.
neuron_object = McNeuron.Neuron(file_format='only list of nodes', input_file=neuron_object.nodes_list)
neuron_object.fit()
features = neuron_object.features
features.keys()

In [ ]:
import pprint as pp
pp.pprint(features['branch_angle_segment'])


In [ ]:
# Visualize a range of real training examples for comparison: the 2D
# morphology, the dendrite-tree plot, and the raw parent-ID vector.
for ex in range(2200, 2300):
    input_code = dict()
    input_code['morphology'] = training_data['morphology']['n20'][ex, :]
    input_code['geometry'] = np.squeeze(training_data['geometry']['n20'][ex, :, :])
    neuron_object = data_transforms.make_swc_from_prufer_and_locations(input_code)
    neuron_object = McNeuron.Neuron(file_format='only list of nodes', input_file=neuron_object.nodes_list)
    McNeuron.visualize.plot_2D(neuron_object)
    McNeuron.visualize.plot_dedrite_tree(neuron_object)
    plt.show()
    plt.plot(training_data['morphology']['n20'][ex, :])
    plt.show()
