In [11]:
%load_ext autoreload
%autoreload 2

import cPickle as pickle
import os
import sys
sys.path.append('..')  # make the gp package importable from this notebook
import gp
import gp.nets as nets

from nolearn.lasagne.visualize import plot_loss
from nolearn.lasagne.visualize import plot_conv_weights
from nolearn.lasagne.visualize import plot_conv_activity
from nolearn.lasagne.visualize import plot_occlusion

from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
%matplotlib inline


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [2]:
PATCH_PATH = 'cylinder2_rgb_small'

In [3]:
X_train, y_train, X_test, y_test = gp.Patch.load_rgb(PATCH_PATH)


Loaded /n/home05/haehn/patches_local//cylinder2_rgb_small/ in 0.0659239292145 seconds.
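
As a quick sanity check (assuming load_rgb returns NumPy arrays, with patches shaped to match the net's 3x75x75 input layer below):

print X_train.shape, y_train.shape   # expect (N, 3, 75, 75) and (N,)
print X_test.shape, y_test.shape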

In [4]:
gp.Util.view_rgba(X_train[100], y_train[100])


[figure: training patch X_train[100] rendered with its label y_train[100]]
In [5]:
cnn = nets.RGBNetPlusPlus()


CNN configuration:
    Our CNN with image, prob, and merged_array as the RGB channels.

    Includes dropout, additional layers, and more filters.
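
RGBNetPlusPlus is defined in gp.nets and not reproduced here. The sketch below is a hypothetical nolearn equivalent reconstructed from the layer table that fit prints in the next cell; 3x3 filters and 2x2 pooling are consistent with the reported 478,498 parameters. The learning rate and momentum are placeholders, not the values used by gp.nets.

from lasagne import layers
from lasagne.nonlinearities import softmax
from nolearn.lasagne import NeuralNet

sketch = NeuralNet(
    layers=[
        ('input',    layers.InputLayer),
        ('conv1',    layers.Conv2DLayer),
        ('pool1',    layers.MaxPool2DLayer),
        ('dropout1', layers.DropoutLayer),
        ('conv2',    layers.Conv2DLayer),
        ('pool2',    layers.MaxPool2DLayer),
        ('dropout2', layers.DropoutLayer),
        ('conv3',    layers.Conv2DLayer),
        ('pool3',    layers.MaxPool2DLayer),
        ('dropout3', layers.DropoutLayer),
        ('conv4',    layers.Conv2DLayer),
        ('pool4',    layers.MaxPool2DLayer),
        ('dropout4', layers.DropoutLayer),
        ('hidden5',  layers.DenseLayer),
        ('dropout5', layers.DropoutLayer),
        ('output',   layers.DenseLayer),
    ],
    input_shape=(None, 3, 75, 75),      # RGB patches
    conv1_num_filters=128, conv1_filter_size=(3, 3), pool1_pool_size=(2, 2),
    conv2_num_filters=96,  conv2_filter_size=(3, 3), pool2_pool_size=(2, 2),
    conv3_num_filters=96,  conv3_filter_size=(3, 3), pool3_pool_size=(2, 2),
    conv4_num_filters=96,  conv4_filter_size=(3, 3), pool4_pool_size=(2, 2),
    hidden5_num_units=512,
    output_num_units=2, output_nonlinearity=softmax,
    update_learning_rate=0.01,          # placeholder hyperparameters --
    update_momentum=0.9,                # the real ones live in gp.nets
    verbose=1,
)

With verbose=1, fit prints the layer table and per-epoch statistics shown below.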

In [6]:
cnn = cnn.fit(X_train, y_train)


# Neural Network with 478498 learnable parameters

## Layer information

  #  name      size
---  --------  ---------
  0  input     3x75x75
  1  conv1     128x73x73
  2  pool1     128x36x36
  3  dropout1  128x36x36
  4  conv2     96x34x34
  5  pool2     96x17x17
  6  dropout2  96x17x17
  7  conv3     96x15x15
  8  pool3     96x7x7
  9  dropout3  96x7x7
 10  conv4     96x5x5
 11  pool4     96x2x2
 12  dropout4  96x2x2
 13  hidden5   512
 14  dropout5  512
 15  output    2

  epoch    trn loss    val loss    trn/val    valid acc  dur
-------  ----------  ----------  ---------  -----------  ------
      1     0.34724     0.23231    1.49473      0.91779  32.17s
      2     0.22757     0.21035    1.08186      0.92428  32.02s
      3     0.20607     0.19018    1.08355      0.93472  32.22s
      4     0.19355     0.18425    1.05046      0.93310  32.21s
      5     0.18282     0.17397    1.05089      0.93751  32.23s
      6     0.17313     0.16317    1.06108      0.94030  32.32s
      7     0.16548     0.15759    1.05011      0.94066  32.39s
      8     0.15674     0.15997    0.97984      0.94039  32.40s
      9     0.14793     0.14345    1.03122      0.94904  32.38s
     10     0.13865     0.14026    0.98856      0.95111  32.40s
     11     0.13161     0.13405    0.98182      0.95147  32.41s
     12     0.12057     0.11318    1.06527      0.96056  32.43s
     13     0.11597     0.11274    1.02865      0.96020  32.42s
     14     0.10810     0.11611    0.93102      0.95957  32.40s
     15     0.10192     0.09989    1.02037      0.96569  32.39s
     16     0.09848     0.10062    0.97874      0.96425  32.39s
     17     0.08961     0.08264    1.08429      0.97155  32.39s
     18     0.08743     0.07923    1.10352      0.97443  32.38s
     19     0.08011     0.08244    0.97176      0.97137  32.40s
     20     0.07804     0.07530    1.03640      0.97578  32.41s
     21     0.07157     0.07159    0.99966      0.97704  32.42s
     22     0.06809     0.06528    1.04304      0.98010  32.40s
     23     0.06396     0.06707    0.95361      0.97884  32.38s
     24     0.06306     0.05864    1.07531      0.98136  32.40s
     25     0.05751     0.05966    0.96398      0.98226  32.41s
     26     0.05598     0.05491    1.01939      0.98271  32.41s
     27     0.05071     0.04682    1.08318      0.98550  32.39s
     28     0.05017     0.04827    1.03941      0.98640  32.39s
     29     0.04527     0.05199    0.87063      0.98271  32.40s
     30     0.04504     0.04657    0.96714      0.98676  32.39s
     31     0.04129     0.05452    0.75735      0.98379  32.36s
     32     0.04446     0.06587    0.67504      0.97947  32.39s
     33     0.04079     0.04345    0.93877      0.98784  32.41s
     34     0.03881     0.04447    0.87280      0.98802  32.40s
     35     0.03681     0.04596    0.80095      0.98811  32.39s
     36     0.03619     0.04112    0.87994      0.98992  32.40s
     37     0.03652     0.03990    0.91537      0.98856  32.38s
     38     0.03290     0.03522    0.93422      0.99163  32.37s
     39     0.02806     0.04176    0.67201      0.98910  32.36s
     40     0.03131     0.04148    0.75480      0.98874  32.37s
     41     0.03328     0.04181    0.79581      0.98865  32.41s
     42     0.02961     0.04808    0.61587      0.98802  32.39s
     43     0.02939     0.04011    0.73277      0.99091  32.38s
     44     0.02883     0.04754    0.60642      0.98784  32.36s
     45     0.02720     0.03899    0.69776      0.99019  32.39s
     46     0.02920     0.04870    0.59960      0.98703  32.37s
     47     0.02411     0.04434    0.54379      0.99118  32.35s
     48     0.02550     0.04570    0.55798      0.98965  32.37s
     49     0.02351     0.04428    0.53083      0.99127  32.39s
     50     0.02107     0.04453    0.47319      0.99037  32.38s
     51     0.02455     0.04411    0.55660      0.99046  32.38s
     52     0.02341     0.03924    0.59644      0.99109  32.40s
     53     0.02093     0.04848    0.43165      0.99100  32.38s
     54     0.02215     0.04193    0.52818      0.99208  32.38s
     55     0.02144     0.04479    0.47873      0.99046  32.38s
     56     0.02137     0.04438    0.48150      0.99055  32.38s
     57     0.02066     0.04473    0.46197      0.99163  32.36s
     58     0.02037     0.04369    0.46619      0.99154  32.35s
     59     0.02103     0.04135    0.50863      0.99163  32.35s
     60     0.01889     0.04578    0.41268      0.99244  32.38s
     61     0.01587     0.04716    0.33645      0.99145  32.39s
     62     0.01901     0.04829    0.39362      0.99064  32.39s
     63     0.01790     0.04812    0.37198      0.99136  32.39s
     64     0.01752     0.04103    0.42692      0.99190  32.39s
     65     0.01800     0.04867    0.36993      0.99082  32.36s
     66     0.01783     0.04230    0.42149      0.99190  32.37s
     67     0.01577     0.04869    0.32386      0.99055  32.38s
     68     0.01398     0.04266    0.32771      0.99154  32.38s
     69     0.01623     0.04443    0.36527      0.99217  32.37s
     70     0.01481     0.04351    0.34031      0.99361  32.38s
     71     0.01420     0.04087    0.34748      0.99226  32.39s
     72     0.01672     0.04704    0.35542      0.99199  32.40s
     73     0.01449     0.04124    0.35131      0.99172  32.37s
     74     0.01217     0.04368    0.27869      0.99280  32.38s
     75     0.01523     0.04076    0.37361      0.99190  32.36s
     76     0.01515     0.04059    0.37330      0.99253  32.35s
     77     0.01576     0.04478    0.35197      0.99271  32.37s
     78     0.01179     0.04707    0.25049      0.99055  32.35s
     79     0.01695     0.05102    0.33226      0.99118  32.37s
     80     0.01235     0.05326    0.23195      0.99001  32.38s
     81     0.01329     0.04139    0.32112      0.99262  32.38s
     82     0.01202     0.05046    0.23817      0.99028  32.39s
     83     0.01159     0.05002    0.23170      0.99082  32.41s
     84     0.01368     0.04273    0.32008      0.99208  32.37s
     85     0.01349     0.04137    0.32603      0.99316  32.35s
     86     0.01227     0.05880    0.20872      0.99046  32.35s
     87     0.01076     0.04834    0.22251      0.99055  32.37s
     88     0.01076     0.04295    0.25061      0.99280  32.38s
Early stopping.
Best valid loss was 0.035216 at epoch 38.
Loaded parameters to layer 'conv1' (shape 128x3x3x3).
Loaded parameters to layer 'conv1' (shape 128).
Loaded parameters to layer 'conv2' (shape 96x128x3x3).
Loaded parameters to layer 'conv2' (shape 96).
Loaded parameters to layer 'conv3' (shape 96x96x3x3).
Loaded parameters to layer 'conv3' (shape 96).
Loaded parameters to layer 'conv4' (shape 96x96x3x3).
Loaded parameters to layer 'conv4' (shape 96).
Loaded parameters to layer 'hidden5' (shape 384x512).
Loaded parameters to layer 'hidden5' (shape 512).
Loaded parameters to layer 'output' (shape 512x2).
Loaded parameters to layer 'output' (shape 2).
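
Training stopped at epoch 88 with the best validation loss at epoch 38, i.e. after 50 epochs without improvement, and then reloaded the best weights into each layer. That is consistent with the standard nolearn on_epoch_finished early-stopping handler; a minimal sketch, assuming a patience of 50 (which matches the log):

import numpy as np

class EarlyStopping(object):
    """Stop training once validation loss has not improved for
    `patience` epochs, then restore the best weights seen so far."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best_valid = np.inf
        self.best_valid_epoch = 0
        self.best_weights = None

    def __call__(self, nn, train_history):
        current_valid = train_history[-1]['valid_loss']
        current_epoch = train_history[-1]['epoch']
        if current_valid < self.best_valid:
            # new best: remember the epoch and snapshot the weights
            self.best_valid = current_valid
            self.best_valid_epoch = current_epoch
            self.best_weights = nn.get_all_params_values()
        elif self.best_valid_epoch + self.patience < current_epoch:
            print('Early stopping.')
            print('Best valid loss was {:.6f} at epoch {}.'.format(
                self.best_valid, self.best_valid_epoch))
            # produces the "Loaded parameters to layer ..." lines above
            nn.load_params_from(self.best_weights)
            raise StopIteration()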

In [7]:
test_accuracy = cnn.score(X_test, y_test)

In [8]:
test_accuracy


Out[8]:
0.9195725534308211
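
Test accuracy (~92%) is noticeably below the final validation accuracy (~99%), so it is worth checking where the errors fall. A quick way, assuming scikit-learn is available in this environment:

from sklearn.metrics import confusion_matrix

y_pred = cnn.predict(X_test)        # hard 0/1 labels
print confusion_matrix(y_test, y_pred)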

In [9]:
plot_loss(cnn)


Out[9]:
<module 'matplotlib.pyplot' from '/n/home05/haehn/nolearncox/lib/python2.7/site-packages/matplotlib-1.5.2-py2.7-linux-x86_64.egg/matplotlib/pyplot.pyc'>

[figure: training vs. validation loss per epoch]
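
plot_conv_weights, plot_conv_activity, and plot_occlusion were imported above but not used in this run; for example (cnn.layers_ maps layer names to Lasagne layers):

plot_conv_weights(cnn.layers_['conv1'], figsize=(6, 6))   # learned 3x3 filters
plot_conv_activity(cnn.layers_['conv1'], X_train[0:1])    # responses to one patch
plot_occlusion(cnn, X_train[:3], y_train[:3])             # occlusion sensitivity maps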

In [12]:
# store CNN -- pickling a nolearn net recurses deeply through the
# Theano graph, so raise the recursion limit first
sys.setrecursionlimit(1000000000)
with open(os.path.expanduser('~/Projects/gp/nets/RGBPlusPlus.p'), 'wb') as f:
    pickle.dump(cnn, f, -1)
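
To restore the trained net later (the same gp/nolearn code must be importable when unpickling):

with open(os.path.expanduser('~/Projects/gp/nets/RGBPlusPlus.p'), 'rb') as f:
    cnn = pickle.load(f)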

In [ ]: