In [1]:
import os
import argparse
import numpy as np
import six
import six.moves.cPickle as pickle
import matplotlib.pyplot as plt
import cv2

import chainer
import chainer.functions as F
from chainer import cuda
from chainer import optimizers
#from chainer import computational_graph as c

In [2]:
%matplotlib inline

Define parameters

parser = argparse.ArgumentParser(description='cnn for neuro classification')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--arch', '-a', default='cnn3')
parser.add_argument('--batchs', '-B', default=30, type=int,
                    help='learning minibatch size')
parser.add_argument('--epoch', '-E', default=20, type=int)
parser.add_argument('--out', '-o', default='model',
                    help='Path to save model')
args = parser.parse_args()
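Note: inside a notebook, `parse_args()` sees the kernel's own command-line arguments and will raise an error; a tolerant variant that ignores unknown arguments is shown below. (The values are overridden manually in the next cell anyway.)

# parse_known_args() ignores the extra arguments the notebook kernel passes
args, _ = parser.parse_known_args()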


In [3]:
n_epoch = 200
batchsize = 25
gpu_flag = 0

N = 900  # size of training data

Get image dataset


In [4]:
dpath = os.path.abspath("")
foldername = "/dataset"

x_data = np.load(dpath + foldername + "/x_data1.npy")
t_data = np.load(dpath + foldername + "/t_data1.npy")
mean = np.load(dpath + foldername + "/mean.npy")

In [5]:
# subtract the per-pixel mean (broadcast over the sample axis)
x_data = x_data - mean[:, :, :, np.newaxis]

# reorder to (sample, channel, height, width) as Chainer expects
x_data = x_data.transpose(3, 0, 1, 2).astype(np.float32)
t_data = t_data.transpose(3, 0, 1, 2).astype(np.float32)

print x_data.shape
print t_data.shape
print len(x_data)

N_test = len(x_data) - N

print N_test


(1000, 1, 255, 255)
(1000, 1, 255, 255)
1000
100
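For reference, a per-pixel mean like `mean.npy` could be produced from the raw stack before the transpose by averaging over the sample axis. This is a sketch under that assumption, not the actual preprocessing script:

# raw x_data here is assumed to be the (1, 255, 255, n_samples) stack, samples on axis 3;
# the result has shape (1, 255, 255) and broadcasts via mean[:, :, :, np.newaxis]
mean = x_data.mean(axis=3).astype(np.float32)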

In [6]:
np.min(x_data)


Out[6]:
-0.41072941

Separate data into train and test


In [7]:
#shuffle the data
shuffle = np.random.permutation(len(t_data))

x_data = np.asarray(x_data[shuffle,:,:,:])
t_data = np.asarray(t_data[shuffle,:,:,:])

In [8]:
# separate into train and test
x_train, x_test = np.split(x_data, [N], axis=0)
t_train, t_test = np.split(t_data, [N], axis=0)

print x_train.shape, t_train.shape


(900, 1, 255, 255) (900, 1, 255, 255)

Prepare model


In [28]:
import segmentation
model = segmentation.CNN_segment3()
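The network itself lives in an external `segmentation` module that is not included here. For orientation only, here is a minimal Chainer v1 (FunctionSet-style) sketch that matches the interface used below (`forward(x, t, train=...)` and `predict(x)`) and the 3-encoder/3-decoder sigmoid description in the plot title; the layer widths, activations, and loss are assumptions, not the actual `CNN_segment3` definition:

# A hypothetical stand-in for segmentation.CNN_segment3 (NOT the real one)
import chainer
import chainer.functions as F

class CNNSegmentSketch(chainer.FunctionSet):
    def __init__(self):
        super(CNNSegmentSketch, self).__init__(
            enc1=F.Convolution2D(1, 16, 3, pad=1),   # encoder
            enc2=F.Convolution2D(16, 32, 3, pad=1),
            enc3=F.Convolution2D(32, 64, 3, pad=1),
            dec3=F.Convolution2D(64, 32, 3, pad=1),  # decoder
            dec2=F.Convolution2D(32, 16, 3, pad=1),
            dec1=F.Convolution2D(16, 1, 3, pad=1),
        )

    def _segment(self, x):
        h = F.relu(self.enc1(x))
        h = F.relu(self.enc2(h))
        h = F.relu(self.enc3(h))
        h = F.relu(self.dec3(h))
        h = F.relu(self.dec2(h))
        return F.sigmoid(self.dec1(h))  # per-pixel output in [0, 1]

    def forward(self, x_data, t_data, train=True):
        x = chainer.Variable(x_data, volatile=not train)
        t = chainer.Variable(t_data, volatile=not train)
        return F.mean_squared_error(self._segment(x), t)

    def predict(self, x_data):
        x = chainer.Variable(x_data, volatile=True)
        return self._segment(x).data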

GPU setup


In [29]:
if gpu_flag >= 0:
    cuda.check_cuda_available()

# use CuPy arrays on GPU, NumPy arrays on CPU
xp = cuda.cupy if gpu_flag >= 0 else np

if gpu_flag >= 0:
    cuda.get_device(gpu_flag).use()
    model.to_gpu()

Set up optimizer (Adam)


In [30]:
optimizer = optimizers.Adam()
optimizer.setup(model)
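This uses Chainer's default Adam hyperparameters. If you wanted to tune the step size, it can be passed explicitly; the value shown is just the library default:

# alpha is Adam's step size (0.001 is Chainer's default)
optimizer = optimizers.Adam(alpha=0.001)
optimizer.setup(model)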

Learning loop


In [31]:
train_loss = []
test_loss = []

for epoch in six.moves.range(1, n_epoch + 1):
    print 'epoch: ', epoch

    # training
    perm = np.random.permutation(N)
    sum_loss = 0
    for i in six.moves.range(0, N, batchsize):
        x_batch = xp.asarray(x_train[perm[i:i + batchsize]])
        t_batch = xp.asarray(t_train[perm[i:i + batchsize]])

        optimizer.zero_grads()
        loss = model.forward(x_batch, t_batch)
        loss.backward()
        optimizer.update()

#        if epoch == 1 and i == 0:
#            with open("graph.dot", "w") as o:
#                o.write(c.build_computational_graph((loss, )).dump())
#            with open("graph.wo_split.dot", "w") as o:
#                g = c.build_computational_graph((loss, ), remove_split=True)
#                o.write(g.dump())
#            print "graph generated"

        sum_loss += float(loss.data) * t_batch.shape[0]

    train_loss.append([epoch, sum_loss / N])

    print "train mean loss={}".format(sum_loss / N)

    # evaluation
    sum_loss = 0
    for i in six.moves.range(0, N_test, batchsize):
        x_batch = xp.asarray(x_test[i:i + batchsize])
        t_batch = xp.asarray(t_test[i:i + batchsize])

        loss = model.forward(x_batch, t_batch, train=False)

        sum_loss += float(loss.data) * t_batch.shape[0]

    test_loss.append([epoch, sum_loss / N_test])

    print "test mean loss={}".format(sum_loss / N_test)

train_loss = np.asarray(train_loss)
test_loss = np.asarray(test_loss)


epoch:  1
train mean loss=0.206677246011
test mean loss=0.228950701654
epoch:  2
train mean loss=0.201947336396
test mean loss=0.232668474317
epoch:  3
train mean loss=0.202099508709
test mean loss=0.233494330198
epoch:  4
train mean loss=0.201753888279
test mean loss=0.233430594206
epoch:  5
train mean loss=0.20131021945
test mean loss=0.23068626225
epoch:  6
train mean loss=0.201312883033
test mean loss=0.228551656008
epoch:  7
train mean loss=0.201315275911
test mean loss=0.231829363853
epoch:  8
train mean loss=0.201364608275
test mean loss=0.229205921292
epoch:  9
train mean loss=0.201162660701
test mean loss=0.230182752013
epoch:  10
train mean loss=0.201136118836
test mean loss=0.230457291007
epoch:  11
train mean loss=0.201177443481
test mean loss=0.229999888688
epoch:  12
train mean loss=0.201582070026
test mean loss=0.234066523612
epoch:  13
train mean loss=0.201351177361
test mean loss=0.229334395379
epoch:  14
train mean loss=0.201126277861
test mean loss=0.229762297124
epoch:  15
train mean loss=0.200590325726
test mean loss=0.221810728312
epoch:  16
train mean loss=0.143078546143
test mean loss=0.133366502821
epoch:  17
train mean loss=0.104190074528
test mean loss=0.0952019523829
epoch:  18
train mean loss=0.0822263407624
test mean loss=0.0824556108564
epoch:  19
train mean loss=0.0733784145365
test mean loss=0.0742314839736
epoch:  20
train mean loss=0.0669299227496
test mean loss=0.076927873306
epoch:  21
train mean loss=0.0643793510066
test mean loss=0.0669555794448
epoch:  22
train mean loss=0.0612824286024
test mean loss=0.069050559774
epoch:  23
train mean loss=0.0576645399754
test mean loss=0.0622087083757
epoch:  24
train mean loss=0.0538818311567
test mean loss=0.0595470983535
epoch:  25
train mean loss=0.047099839482
test mean loss=0.0469946991652
epoch:  26
train mean loss=0.046780786477
test mean loss=0.046920241788
epoch:  27
train mean loss=0.0446559016903
test mean loss=0.0469967527315
epoch:  28
train mean loss=0.0422575210945
test mean loss=0.0435978220776
epoch:  29
train mean loss=0.0422314523409
test mean loss=0.0435176156461
epoch:  30
train mean loss=0.0397109365505
test mean loss=0.0432093245909
epoch:  31
train mean loss=0.0398537928445
test mean loss=0.0423762528226
epoch:  32
train mean loss=0.0403582407162
test mean loss=0.0418273620307
epoch:  33
train mean loss=0.0386848288795
test mean loss=0.0401472151279
epoch:  34
train mean loss=0.0379213005718
test mean loss=0.0401535350829
epoch:  35
train mean loss=0.0368844433688
test mean loss=0.0391914788634
epoch:  36
train mean loss=0.0372499938951
test mean loss=0.0389050394297
epoch:  37
train mean loss=0.0366763009483
test mean loss=0.040657447651
epoch:  38
train mean loss=0.0367701537907
test mean loss=0.0384933110327
epoch:  39
train mean loss=0.0347815720985
test mean loss=0.0375592056662
epoch:  40
train mean loss=0.0343150352645
test mean loss=0.0369010139257
epoch:  41
train mean loss=0.0348246572022
test mean loss=0.036966486834
epoch:  42
train mean loss=0.0337891901015
test mean loss=0.0370500693098
epoch:  43
train mean loss=0.0329461192402
test mean loss=0.0357038117945
epoch:  44
train mean loss=0.0324724273653
test mean loss=0.0360838025808
epoch:  45
train mean loss=0.0322457302051
test mean loss=0.0351647501811
epoch:  46
train mean loss=0.0328451858626
test mean loss=0.033919095993
epoch:  47
train mean loss=0.0306500575712
test mean loss=0.0323217641562
epoch:  48
train mean loss=0.0301175189929
test mean loss=0.0342939170077
epoch:  49
train mean loss=0.0298991836607
test mean loss=0.0330121470615
epoch:  50
train mean loss=0.0307385059487
test mean loss=0.0316237299703
epoch:  51
train mean loss=0.0288758189935
test mean loss=0.033581124153
epoch:  52
train mean loss=0.0321486288578
test mean loss=0.0311182602309
epoch:  53
train mean loss=0.0285555775691
test mean loss=0.0304270270281
epoch:  54
train mean loss=0.0283211427223
test mean loss=0.0305521842092
epoch:  55
train mean loss=0.0275067605803
test mean loss=0.0289418236353
epoch:  56
train mean loss=0.0272491247807
test mean loss=0.0287932264619
epoch:  57
train mean loss=0.027057814774
test mean loss=0.0305441035889
epoch:  58
train mean loss=0.0263485063074
test mean loss=0.0284247309901
epoch:  59
train mean loss=0.0274482205924
test mean loss=0.02830545092
epoch:  60
train mean loss=0.0257329224195
test mean loss=0.0278501501307
epoch:  61
train mean loss=0.0266516280567
test mean loss=0.0289295916446
epoch:  62
train mean loss=0.0260383485713
test mean loss=0.0273826192133
epoch:  63
train mean loss=0.0255943039018
test mean loss=0.0275288932025
epoch:  64
train mean loss=0.0251304923246
test mean loss=0.0266703893431
epoch:  65
train mean loss=0.0245036441936
test mean loss=0.0268567395397
epoch:  66
train mean loss=0.0243813675932
test mean loss=0.0261453534476
epoch:  67
train mean loss=0.0248390143323
test mean loss=0.0274124559946
epoch:  68
train mean loss=0.0243636186028
test mean loss=0.0265530869365
epoch:  69
train mean loss=0.0236252692218
test mean loss=0.0266521628946
epoch:  70
train mean loss=0.0238457756738
test mean loss=0.0259001571685
epoch:  71
train mean loss=0.0238352227542
test mean loss=0.0260332967155
epoch:  72
train mean loss=0.0232218200237
test mean loss=0.025124442298
epoch:  73
train mean loss=0.0228371271967
test mean loss=0.0259036482312
epoch:  74
train mean loss=0.0231880304766
test mean loss=0.0265773572028
epoch:  75
train mean loss=0.0230943008533
test mean loss=0.0254131807014
epoch:  76
train mean loss=0.022415973867
test mean loss=0.0242331656627
epoch:  77
train mean loss=0.0221647322178
test mean loss=0.0245547466911
epoch:  78
train mean loss=0.0219350903709
test mean loss=0.0241855839267
epoch:  79
train mean loss=0.0238864127443
test mean loss=0.0244333064184
epoch:  80
train mean loss=0.022366863365
test mean loss=0.0243595470674
epoch:  81
train mean loss=0.0215113426352
test mean loss=0.0234177405946
epoch:  82
train mean loss=0.0213365492721
test mean loss=0.0234494446777
epoch:  83
train mean loss=0.0211869727872
test mean loss=0.0232970151119
epoch:  84
train mean loss=0.0209611813124
test mean loss=0.0236502252519
epoch:  85
train mean loss=0.0207111954482
test mean loss=0.0232494883239
epoch:  86
train mean loss=0.0210265043812
test mean loss=0.0225792219862
epoch:  87
train mean loss=0.0204238992754
test mean loss=0.0227606277913
epoch:  88
train mean loss=0.0204414384595
test mean loss=0.023008194752
epoch:  89
train mean loss=0.0202117541598
test mean loss=0.0223840316758
epoch:  90
train mean loss=0.0199102298874
test mean loss=0.0228005624376
epoch:  91
train mean loss=0.0204007328591
test mean loss=0.0224753529765
epoch:  92
train mean loss=0.0204533420555
test mean loss=0.0227065724321
epoch:  93
train mean loss=0.0283634038125
test mean loss=0.0268604476005
epoch:  94
train mean loss=0.0227651700067
test mean loss=0.0240204334259
epoch:  95
train mean loss=0.0202689527844
test mean loss=0.02283362858
epoch:  96
train mean loss=0.0198346519222
test mean loss=0.0217878799886
epoch:  97
train mean loss=0.0194653567548
test mean loss=0.0217227512039
epoch:  98
train mean loss=0.0192060421428
test mean loss=0.0225362083875
epoch:  99
train mean loss=0.0194091142621
test mean loss=0.0230036522262
epoch:  100
train mean loss=0.0189752629441
test mean loss=0.0214040591381
epoch:  101
train mean loss=0.0188666953602
test mean loss=0.0222549885511
epoch:  102
train mean loss=0.0185890884087
test mean loss=0.0212241769768
epoch:  103
train mean loss=0.0184621673284
test mean loss=0.0207211463712
epoch:  104
train mean loss=0.0184243015376
test mean loss=0.0214123846963
epoch:  105
train mean loss=0.0185223872152
test mean loss=0.0207401379012
epoch:  106
train mean loss=0.0183461625905
test mean loss=0.0210442687385
epoch:  107
train mean loss=0.0180224548353
test mean loss=0.0208351304755
epoch:  108
train mean loss=0.0183529301236
test mean loss=0.0213002352975
epoch:  109
train mean loss=0.0181926246215
test mean loss=0.0210190271027
epoch:  110
train mean loss=0.0181777220375
test mean loss=0.0202140081674
epoch:  111
train mean loss=0.0177141333309
test mean loss=0.0205419915728
epoch:  112
train mean loss=0.0176406391224
test mean loss=0.0201616650447
epoch:  113
train mean loss=0.0175319604783
test mean loss=0.0205100276507
epoch:  114
train mean loss=0.0172117874026
test mean loss=0.0197305176407
epoch:  115
train mean loss=0.0178009552053
test mean loss=0.020227282308
epoch:  116
train mean loss=0.0172666256419
test mean loss=0.0198133816011
epoch:  117
train mean loss=0.0171082887747
test mean loss=0.019744975958
epoch:  118
train mean loss=0.0170119391082
test mean loss=0.0194781813771
epoch:  119
train mean loss=0.0168585452355
test mean loss=0.0196764646098
epoch:  120
train mean loss=0.0167119534065
test mean loss=0.0210060221143
epoch:  121
train mean loss=0.0167229306988
test mean loss=0.0198572780937
epoch:  122
train mean loss=0.0164335338244
test mean loss=0.0196756501682
epoch:  123
train mean loss=0.016831631617
test mean loss=0.0191309549846
epoch:  124
train mean loss=0.0163323900455
test mean loss=0.0200268412009
epoch:  125
train mean loss=0.0161333904964
test mean loss=0.0191271975636
epoch:  126
train mean loss=0.0159427191959
test mean loss=0.0192780666985
epoch:  127
train mean loss=0.0165076063325
test mean loss=0.0199507349171
epoch:  128
train mean loss=0.0164323223289
test mean loss=0.0189176290296
epoch:  129
train mean loss=0.0159513610415
test mean loss=0.018770063296
epoch:  130
train mean loss=0.0160358240052
test mean loss=0.0199493775144
epoch:  131
train mean loss=0.0156115398535
test mean loss=0.0187151418068
epoch:  132
train mean loss=0.0160898741386
test mean loss=0.0187093387358
epoch:  133
train mean loss=0.0158620996711
test mean loss=0.019188163802
epoch:  134
train mean loss=0.0154710852593
test mean loss=0.0189778134227
epoch:  135
train mean loss=0.0152893142723
test mean loss=0.0188150047325
epoch:  136
train mean loss=0.0158203100372
test mean loss=0.0186368403956
epoch:  137
train mean loss=0.0150516874209
test mean loss=0.0187016869895
epoch:  138
train mean loss=0.0151161510083
test mean loss=0.0189283369109
epoch:  139
train mean loss=0.0150311419016
test mean loss=0.0183900070842
epoch:  140
train mean loss=0.0152743471165
test mean loss=0.0182519885711
epoch:  141
train mean loss=0.0154347377829
test mean loss=0.0195668921806
epoch:  142
train mean loss=0.0155557050214
test mean loss=0.0180555176921
epoch:  143
train mean loss=0.0147315712594
test mean loss=0.0182200388517
epoch:  144
train mean loss=0.0149050309685
test mean loss=0.0184744501021
epoch:  145
train mean loss=0.0147593450577
test mean loss=0.0178669744637
epoch:  146
train mean loss=0.0144108704084
test mean loss=0.0177970170043
epoch:  147
train mean loss=0.014486595652
test mean loss=0.0177030381747
epoch:  148
train mean loss=0.0148994092985
test mean loss=0.0193656967022
epoch:  149
train mean loss=0.0151375442899
test mean loss=0.0181356922258
epoch:  150
train mean loss=0.0147933258365
test mean loss=0.0176573721692
epoch:  151
train mean loss=0.0141364627828
test mean loss=0.0178797622211
epoch:  152
train mean loss=0.0138949060606
test mean loss=0.0175818549469
epoch:  153
train mean loss=0.0138014262904
test mean loss=0.0175356608815
epoch:  154
train mean loss=0.0139704848019
test mean loss=0.0184059950989
epoch:  155
train mean loss=0.0140183087852
test mean loss=0.0175037120935
epoch:  156
train mean loss=0.0134656178351
test mean loss=0.0173712915275
epoch:  157
train mean loss=0.0136194312686
test mean loss=0.0175589621067
epoch:  158
train mean loss=0.0135069787761
test mean loss=0.0170980447438
epoch:  159
train mean loss=0.0134644996271
test mean loss=0.0173705276102
epoch:  160
train mean loss=0.0132751446735
test mean loss=0.0174023851287
epoch:  161
train mean loss=0.0150856605421
test mean loss=0.0196877552662
epoch:  162
train mean loss=0.01553606868
test mean loss=0.0175515124574
epoch:  163
train mean loss=0.0132303107271
test mean loss=0.0169035803992
epoch:  164
train mean loss=0.0129619474368
test mean loss=0.0169444142375
epoch:  165
train mean loss=0.0128885797587
test mean loss=0.0175147734117
epoch:  166
train mean loss=0.0127889480338
test mean loss=0.0168947647326
epoch:  167
train mean loss=0.0129228620014
test mean loss=0.0171181070618
epoch:  168
train mean loss=0.012805651744
test mean loss=0.0171523219906
epoch:  169
train mean loss=0.0129688888943
test mean loss=0.0168327412102
epoch:  170
train mean loss=0.0128355889788
test mean loss=0.0168906385079
epoch:  171
train mean loss=0.0127647655706
test mean loss=0.0170788578689
epoch:  172
train mean loss=0.0126207859204
test mean loss=0.0168128409423
epoch:  173
train mean loss=0.0124804021584
test mean loss=0.0165929999202
epoch:  174
train mean loss=0.0123788633436
test mean loss=0.0166446636431
epoch:  175
train mean loss=0.0124625456519
test mean loss=0.0164837883785
epoch:  176
train mean loss=0.0121379908847
test mean loss=0.0167113407515
epoch:  177
train mean loss=0.0122176552864
test mean loss=0.0176752167754
epoch:  178
train mean loss=0.0125950837974
test mean loss=0.0163622286636
epoch:  179
train mean loss=0.0120957863207
test mean loss=0.0171133247204
epoch:  180
train mean loss=0.0120910127492
test mean loss=0.0165374001954
epoch:  181
train mean loss=0.0120640969318
test mean loss=0.0173208210617
epoch:  182
train mean loss=0.0120251612583
test mean loss=0.0163645492867
epoch:  183
train mean loss=0.0118056316601
test mean loss=0.0162279875949
epoch:  184
train mean loss=0.0118015291325
test mean loss=0.0162295522168
epoch:  185
train mean loss=0.0117308935927
test mean loss=0.0163251243066
epoch:  186
train mean loss=0.0117175743605
test mean loss=0.0160972527228
epoch:  187
train mean loss=0.0118718983916
test mean loss=0.0167089093011
epoch:  188
train mean loss=0.012078937764
test mean loss=0.0163892833516
epoch:  189
train mean loss=0.0115139968693
test mean loss=0.0161944343708
epoch:  190
train mean loss=0.0118418190266
test mean loss=0.0168527427595
epoch:  191
train mean loss=0.011922905144
test mean loss=0.0164730516262
epoch:  192
train mean loss=0.0114416922443
test mean loss=0.0159411458299
epoch:  193
train mean loss=0.0112919737585
test mean loss=0.0159709269647
epoch:  194
train mean loss=0.0112666312036
test mean loss=0.0164770428091
epoch:  195
train mean loss=0.0113828774128
test mean loss=0.0160424578935
epoch:  196
train mean loss=0.0110584128415
test mean loss=0.0157757503912
epoch:  197
train mean loss=0.0114673835277
test mean loss=0.0159407143947
epoch:  198
train mean loss=0.0112063458396
test mean loss=0.0161040162202
epoch:  199
train mean loss=0.0109570801336
test mean loss=0.0156895471737
epoch:  200
train mean loss=0.0110415421788
test mean loss=0.0156811249908

Plot & save graph


In [32]:
fig, ax1 = plt.subplots()
ax1.plot(train_loss[:, 0], train_loss[:, 1], label='training loss')
ax1.plot(test_loss[:, 0], test_loss[:, 1], label='test loss')
ax1.set_xlim([1, len(train_loss)])
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')


ax1.legend(bbox_to_anchor=(0.25, -0.1), loc=9)
plt.title("6-layer (encoder3+decoder3) convolutional sigmoid segmentation net")
plt.savefig(dpath+"/figs/segmentation3_sig", bbox_inches='tight')

plt.show()


Save final model


In [33]:
# -1 selects the highest available pickle protocol
pickle.dump(model, open('sig_segment3', 'wb'), -1)

Test segmentation


In [34]:
# unpickling requires the segmentation module to be importable
model = pickle.load(open('sig_segment3', 'rb'))

In [35]:
num = 28  # index of the test sample to visualize

Raw data


In [36]:
plt.imshow(x_test[num, 0, :], cmap=plt.get_cmap('gray'))


Out[36]:
<matplotlib.image.AxesImage at 0x7f967595da10>

Ground truth


In [37]:
plt.imshow(t_test[num, 0, :], cmap=plt.get_cmap('gray'))


Out[37]:
<matplotlib.image.AxesImage at 0x7f9675886ad0>

Predicted map


In [38]:
# run the model image by image and stack the outputs on the sample axis
predictions = []
for i in six.moves.range(len(x_test)):
    x = xp.asarray(x_test[i:i + 1])
    y = cuda.to_cpu(model.predict(x))
    predictions.append(np.asarray(y))
predict = np.concatenate(predictions, axis=0)
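The loop above collects sigmoid outputs in [0, 1]. If a hard segmentation mask is needed, the maps can be thresholded; 0.5 below is an arbitrary cutoff, not something fixed by the model:

# binarize the probability maps (threshold chosen by hand)
masks = (predict > 0.5).astype(np.uint8)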

In [39]:
plt.imshow(predict[num, 0, :], cmap=plt.get_cmap('gray'))


Out[39]:
<matplotlib.image.AxesImage at 0x7f9674f80610>

Save result


In [40]:
np.save(dpath + "/result/rawimage_sig3.npy", x_test)
np.save(dpath + "/result/groundtruth_sig3.npy", t_test)
np.save(dpath + "/result/predict_sig3.npy", predict)

In [ ]: