In [1]:
%load_ext autoreload
%autoreload 2

import cPickle as pickle  # Python 2 C-accelerated pickle (merged into `pickle` in Py3)
import os; import sys; sys.path.append('..')  # make parent dir importable so `gp` resolves
import gp
import gp.nets as nets

# nolearn visualization helpers: loss curves, conv filter weights,
# per-image conv activations, and occlusion-based saliency maps.
from nolearn.lasagne.visualize import plot_loss
from nolearn.lasagne.visualize import plot_conv_weights
from nolearn.lasagne.visualize import plot_conv_activity
from nolearn.lasagne.visualize import plot_occlusion

from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
%matplotlib inline


/home/d/nolearn/local/lib/python2.7/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
  warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')
Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled, CuDNN 4007)
/home/d/nolearn/local/lib/python2.7/site-packages/theano/tensor/signal/downsample.py:6: UserWarning: downsample module has been moved to the theano.tensor.signal.pool module.
  "downsample module has been moved to the theano.tensor.signal.pool module.")

In [3]:
PATCH_PATH = ('cylinder2_rgba_small')

In [4]:
X_train, y_train, X_test, y_test = gp.Patch.load_rgba(PATCH_PATH)


Loaded /home/d/patches//cylinder2_rgba_small/ in 0.16304898262 seconds.

In [22]:
gp.Util.view_rgba(X_train[100], y_train[100])



In [25]:
cnn = nets.RGBANet()


CNN configuration: 
    Our CNN with image, prob, merged_array and border overlap as RGBA.

    This includes dropout.
    

In [26]:
cnn = cnn.fit(X_train, y_train)


# Neural Network with 7134066 learnable parameters

## Layer information

  #  name      size
---  --------  --------
  0  input     4x75x75
  1  conv1     64x73x73
  2  pool1     64x36x36
  3  dropout1  64x36x36
  4  conv2     48x34x34
  5  pool2     48x17x17
  6  dropout2  48x17x17
  7  hidden3   512
  8  dropout3  512
  9  output    2

  epoch    train loss    valid loss    train/val    valid acc  dur
-------  ------------  ------------  -----------  -----------  ------
      1       0.29179       0.21075      1.38452      0.91803  25.29s
      2       0.19556       0.17478      1.11886      0.93354  25.28s
      3       0.16524       0.14738      1.12112      0.94464  25.27s
      4       0.13552       0.12874      1.05263      0.95317  25.35s
      5       0.10431       0.09559      1.09123      0.96613  25.33s
      6       0.07986       0.07920      1.00832      0.97178  25.30s
      7       0.06338       0.06350      0.99820      0.98082  25.37s
      8       0.04695       0.08050      0.58320      0.97457  25.48s
      9       0.03873       0.04434      0.87349      0.98740  25.45s
     10       0.03205       0.04062      0.78901      0.98731  25.45s
     11       0.02481       0.03864      0.64205      0.98983  25.46s
     12       0.01947       0.03795      0.51303      0.99072  25.46s
     13       0.01929       0.03571      0.54017      0.99018  25.46s
     14       0.01839       0.04099      0.44851      0.99097  25.44s
     15       0.01478       0.03803      0.38876      0.99126  25.44s
     16       0.01165       0.03516      0.33130      0.99234  25.44s
     17       0.01028       0.04292      0.23958      0.99126  25.45s
     18       0.00873       0.04098      0.21296      0.99126  25.45s
     19       0.00897       0.04412      0.20324      0.99171  25.46s
     20       0.01151       0.04099      0.28085      0.99153  25.45s
     21       0.00849       0.04151      0.20444      0.99162  25.45s
     22       0.00770       0.04296      0.17931      0.99132  25.45s
     23       0.00739       0.04930      0.14983      0.99180  25.45s
     24       0.00579       0.05017      0.11535      0.99097  25.44s
     25       0.00519       0.05995      0.08655      0.99090  25.44s
     26       0.00615       0.04067      0.15125      0.99207  25.45s
     27       0.00753       0.04245      0.17739      0.99315  25.45s
     28       0.00667       0.04175      0.15974      0.99279  25.43s
     29       0.00436       0.04009      0.10883      0.99279  25.43s
     30       0.00411       0.05633      0.07288      0.99189  25.43s
     31       0.00551       0.04092      0.13460      0.99369  25.43s
     32       0.00431       0.04565      0.09440      0.99171  25.44s
     33       0.00282       0.05117      0.05513      0.99153  25.44s
     34       0.00537       0.03950      0.13588      0.99261  25.45s
     35       0.00331       0.05039      0.06562      0.99243  25.44s

In [32]:
cnn = cnn.fit(X_train, y_train)


     36       0.00240       0.04912      0.04885      0.99189  25.31s
     37       0.00263       0.05617      0.04687      0.99162  25.25s
     38       0.00337       0.04676      0.07214      0.99225  25.36s
     39       0.00368       0.04230      0.08706      0.99225  25.45s
     40       0.00460       0.05154      0.08920      0.99171  25.46s
     41       0.00199       0.05178      0.03845      0.99153  25.46s
     42       0.00164       0.05414      0.03024      0.99117  25.46s
     43       0.00164       0.04746      0.03447      0.99279  25.45s
     44       0.00257       0.07061      0.03645      0.99016  25.45s
     45       0.00272       0.04788      0.05684      0.99261  25.47s
     46       0.00322       0.05044      0.06374      0.99225  25.45s
     47       0.00142       0.04915      0.02895      0.99189  25.46s
     48       0.00497       0.06113      0.08135      0.99180  25.45s
     49       0.00319       0.04843      0.06594      0.99279  25.45s
     50       0.00309       0.04497      0.06868      0.99243  25.46s
     51       0.00301       0.05209      0.05770      0.99189  25.45s
     52       0.00226       0.04459      0.05068      0.99297  25.46s
     53       0.00458       0.04665      0.09811      0.99225  25.46s
     54       0.00341       0.04442      0.07682      0.99261  25.46s
     55       0.00181       0.04980      0.03627      0.99279  25.46s
     56       0.00203       0.04549      0.04451      0.99099  25.46s
     57       0.00142       0.04496      0.03160      0.99261  25.47s
     58       0.00091       0.05111      0.01786      0.99225  25.46s
     59       0.00128       0.04837      0.02654      0.99279  25.46s
     60       0.00192       0.04707      0.04077      0.99351  25.46s
     61       0.00229       0.04862      0.04715      0.99207  25.45s
     62       0.00318       0.05552      0.05724      0.99333  25.50s
     63       0.00196       0.05522      0.03553      0.99279  25.46s
     64       0.00242       0.05030      0.04808      0.99261  25.46s
     65       0.00189       0.03927      0.04824      0.99297  25.47s
     66       0.00178       0.05818      0.03064      0.99261  25.46s
Early stopping.
Best valid loss was 0.035157 at epoch 16.
Loaded parameters to layer 'conv1' (shape 64x4x3x3).
Loaded parameters to layer 'conv1' (shape 64).
Loaded parameters to layer 'conv2' (shape 48x64x3x3).
Loaded parameters to layer 'conv2' (shape 48).
Loaded parameters to layer 'hidden3' (shape 13872x512).
Loaded parameters to layer 'hidden3' (shape 512).
Loaded parameters to layer 'output' (shape 512x2).
Loaded parameters to layer 'output' (shape 2).

In [33]:
test_accuracy = cnn.score(X_test, y_test)

In [34]:
test_accuracy


Out[34]:
0.9254780652418447

In [35]:
plot_loss(cnn)


Out[35]:
<module 'matplotlib.pyplot' from '/home/d/nolearn/local/lib/python2.7/site-packages/matplotlib/pyplot.pyc'>

In [46]:
plot_conv_weights(cnn.layers_['conv2'])


/home/d/nolearn/local/lib/python2.7/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
Out[46]:
<module 'matplotlib.pyplot' from '/home/d/nolearn/local/lib/python2.7/site-packages/matplotlib/pyplot.pyc'>

In [48]:
# store CNN
# Pickling a nolearn net recurses deeply through the layer graph, hence the
# raised recursion limit.
# NOTE(review): 1e9 is far beyond what the C stack can absorb — exceeding the
# real stack can crash the interpreter outright instead of raising
# RuntimeError. A value around 1e4-1e5 is the usual workaround; confirm the
# net still pickles before lowering.
sys.setrecursionlimit(1000000000)
with open(os.path.expanduser('~/Projects/gp/nets/RGBA.p'), 'wb') as f:
  pickle.dump(cnn, f, -1)  # protocol -1 = highest protocol available

In [5]:
# Reload the pickled network from the previous session.
# WARNING: pickle.load executes arbitrary code from the file — only load
# pickles produced by this notebook, never files from untrusted sources.
with open(os.path.expanduser('~/Projects/gp/nets/RGBA.p'), 'rb') as f:
  net = pickle.load(f)

In [7]:
from sklearn.metrics import classification_report, accuracy_score, roc_curve, auc, precision_recall_fscore_support, f1_score, precision_recall_curve, average_precision_score, zero_one_loss

In [8]:
# Hard predictions and class probabilities on the test set, then a per-class
# precision/recall/F1 report.
# NOTE(review): test_prediction_prob is unused in the visible cells (likely
# intended for an ROC/PR curve) — confirm a later cell needs it before removing.
test_prediction = net.predict(X_test)
test_prediction_prob = net.predict_proba(X_test)
print
print 'Precision/Recall:'
print classification_report(y_test, test_prediction)


Precision/Recall:
             precision    recall  f1-score   support

          0       0.90      0.96      0.93      3556
          1       0.96      0.89      0.92      3556

avg / total       0.93      0.93      0.93      7112


In [ ]: