In [1]:
%load_ext autoreload
%autoreload 2
import cPickle as pickle
import os; import sys; sys.path.append('..')
import gp
import gp.nets as nets
from nolearn.lasagne.visualize import plot_loss
from nolearn.lasagne.visualize import plot_conv_weights
from nolearn.lasagne.visualize import plot_conv_activity
from nolearn.lasagne.visualize import plot_occlusion
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
%matplotlib inline
Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled, CuDNN 4007)
/home/d/nolearn/local/lib/python2.7/site-packages/theano/tensor/signal/downsample.py:6: UserWarning: downsample module has been moved to the theano.tensor.signal.pool module.
"downsample module has been moved to the theano.tensor.signal.pool module.")
In [2]:
# Name of the patch set to load; the loader resolves it under its own patch
# root directory (see the "Loaded ..." message printed by the next cell).
PATCH_PATH = 'cylinder2_rgba_small'
In [3]:
# Load the train/test split of RGBA patches for PATCH_PATH.
(X_train, y_train,
 X_test, y_test) = gp.Patch.load_rgba(PATCH_PATH)
Loaded /home/d/patches//cylinder2_rgba_small/ in 0.152738809586 seconds.
In [4]:
# Visual sanity check of one (arbitrarily chosen) training patch and its label.
sample_idx = 100
gp.Util.view_rgba(X_train[sample_idx], y_train[sample_idx])
In [7]:
# Build the "PlusPlus" RGBA network from the project's gp.nets module.
# Constructing it prints the configuration summary shown below (4-channel
# input: image, prob, merged_array and border overlap; dropout; extra layers).
cnn = nets.RGBANetPlusPlus()
CNN configuration:
Our CNN with image, prob, merged_array and border overlap as RGBA.
This includes dropout. This also includes more layers. And more filters.
In [8]:
# Train on the training patches. The log below shows early stopping after
# epoch 105 and the best-validation-loss weights (epoch 55) being loaded back.
# The reassignment assumes fit() returns the network itself (nolearn
# convention) — confirm if the net class changes.
# NOTE(review): the "valid" columns presumably come from a split of X_train
# made by the net's train_split, not from X_test — confirm.
# NOTE(review): no random seed is set anywhere above, so weight init, dropout
# and batch shuffling are presumably not reproducible between runs.
cnn = cnn.fit(X_train, y_train)
# Neural Network with 479650 learnable parameters
## Layer information
# name size
--- -------- ---------
0 input 4x75x75
1 conv1 128x73x73
2 pool1 128x36x36
3 dropout1 128x36x36
4 conv2 96x34x34
5 pool2 96x17x17
6 dropout2 96x17x17
7 conv3 96x15x15
8 pool3 96x7x7
9 dropout3 96x7x7
10 conv4 96x5x5
11 pool4 96x2x2
12 dropout4 96x2x2
13 hidden5 512
14 dropout5 512
15 output 2
epoch train loss valid loss train/val valid acc dur
------- ------------ ------------ ----------- ----------- ------
1 0.33431 0.22689 1.47346 0.91620 48.19s
2 0.21704 0.19538 1.11083 0.92938 47.97s
3 0.20027 0.18151 1.10331 0.93111 48.42s
4 0.18937 0.17595 1.07624 0.93617 48.42s
5 0.17883 0.16980 1.05318 0.93839 48.46s
6 0.17106 0.16112 1.06175 0.94218 49.33s
7 0.16242 0.15207 1.06810 0.94503 49.34s
8 0.15390 0.15342 1.00313 0.94647 49.39s
9 0.14533 0.13853 1.04911 0.94769 49.49s
10 0.13931 0.12539 1.11098 0.95439 49.81s
11 0.13162 0.12577 1.04652 0.95353 49.96s
12 0.12593 0.12388 1.01659 0.95685 49.84s
13 0.11831 0.10363 1.14169 0.96403 49.75s
14 0.10899 0.10540 1.03409 0.96292 49.78s
15 0.10309 0.10306 1.00035 0.96438 49.78s
16 0.09788 0.09109 1.07456 0.96852 49.67s
17 0.09398 0.08441 1.11347 0.97008 49.66s
18 0.08779 0.08150 1.07715 0.97399 49.97s
19 0.08164 0.07443 1.09678 0.97408 50.18s
20 0.07852 0.07580 1.03597 0.97205 50.17s
21 0.07580 0.06618 1.14539 0.97684 50.18s
22 0.06998 0.06041 1.15851 0.97998 50.23s
23 0.06659 0.05958 1.11758 0.98120 50.19s
24 0.06583 0.05669 1.16121 0.98273 50.18s
25 0.06293 0.05255 1.19759 0.98351 50.17s
26 0.05770 0.05111 1.12878 0.98399 50.21s
27 0.05203 0.04877 1.06691 0.98560 50.19s
28 0.05347 0.05067 1.05535 0.98423 50.20s
29 0.05104 0.04848 1.05276 0.98297 50.19s
30 0.04694 0.04444 1.05622 0.98731 50.19s
31 0.04681 0.04111 1.13876 0.98776 50.21s
32 0.04323 0.04366 0.99031 0.98767 50.20s
33 0.04248 0.03791 1.12047 0.98839 50.17s
34 0.04003 0.03544 1.12947 0.99000 50.21s
35 0.03658 0.03883 0.94208 0.98977 50.18s
36 0.04040 0.03458 1.16819 0.99198 50.19s
37 0.03455 0.03425 1.00865 0.99141 50.22s
38 0.03708 0.02946 1.25855 0.99297 50.18s
39 0.03378 0.03351 1.00821 0.99099 50.21s
40 0.03260 0.03149 1.03531 0.99252 50.19s
41 0.03191 0.03075 1.03775 0.99234 50.20s
42 0.02732 0.03014 0.90620 0.99261 50.22s
43 0.03092 0.03078 1.00469 0.99195 50.18s
44 0.02703 0.03389 0.79767 0.99189 50.18s
45 0.02638 0.03587 0.73549 0.99252 50.18s
46 0.02591 0.03059 0.84689 0.99342 50.22s
47 0.02248 0.03434 0.65468 0.99189 50.20s
48 0.02364 0.03853 0.61355 0.99000 50.18s
49 0.02166 0.03317 0.65298 0.99180 50.17s
50 0.02283 0.03327 0.68623 0.99288 50.18s
51 0.02210 0.03133 0.70520 0.99252 50.17s
52 0.02453 0.02811 0.87249 0.99285 50.19s
53 0.02017 0.03223 0.62579 0.99252 50.20s
54 0.02638 0.03012 0.87575 0.99378 50.14s
55 0.01967 0.02571 0.76501 0.99414 50.10s
56 0.01974 0.03251 0.60722 0.99351 50.13s
57 0.01801 0.02944 0.61162 0.99279 50.13s
58 0.02112 0.03566 0.59211 0.99270 50.14s
59 0.02049 0.04467 0.45867 0.99153 50.14s
60 0.01742 0.03600 0.48373 0.99315 50.13s
61 0.01889 0.03262 0.57897 0.99360 50.13s
62 0.01871 0.03760 0.49753 0.99279 50.10s
63 0.01777 0.03799 0.46781 0.99288 50.11s
64 0.01842 0.03367 0.54714 0.99261 50.12s
65 0.01607 0.03482 0.46142 0.99324 50.11s
66 0.01785 0.03465 0.51505 0.99234 50.12s
67 0.01684 0.03043 0.55332 0.99432 50.13s
68 0.01464 0.03188 0.45914 0.99387 50.14s
69 0.01753 0.02925 0.59934 0.99270 50.13s
70 0.01324 0.03369 0.39288 0.99330 50.12s
71 0.01463 0.03610 0.40524 0.99249 50.14s
72 0.01611 0.03519 0.45771 0.99414 50.13s
73 0.01498 0.02922 0.51254 0.99432 50.13s
74 0.01388 0.03822 0.36325 0.99324 50.11s
75 0.01500 0.03399 0.44128 0.99297 50.19s
76 0.01567 0.04126 0.37978 0.99252 50.13s
77 0.01425 0.04064 0.35056 0.99261 50.13s
78 0.01304 0.03837 0.33995 0.99315 50.12s
79 0.01329 0.03101 0.42872 0.99414 50.14s
80 0.01046 0.03631 0.28822 0.99432 50.11s
81 0.01391 0.03964 0.35093 0.99270 50.13s
82 0.01049 0.03244 0.32340 0.99414 50.12s
83 0.01095 0.03935 0.27833 0.99270 50.11s
84 0.01364 0.03110 0.43842 0.99414 50.12s
85 0.01352 0.04523 0.29887 0.99243 50.10s
86 0.01209 0.04643 0.26039 0.99171 50.08s
87 0.01244 0.03370 0.36924 0.99432 50.28s
88 0.01256 0.04060 0.30934 0.99234 50.28s
89 0.01156 0.02951 0.39165 0.99378 50.39s
90 0.01306 0.03555 0.36741 0.99396 50.23s
91 0.01086 0.03781 0.28710 0.99285 50.25s
92 0.01110 0.03192 0.34772 0.99548 50.29s
93 0.01113 0.02933 0.37936 0.99503 50.36s
94 0.01202 0.03787 0.31733 0.99360 50.41s
95 0.01068 0.03624 0.29464 0.99467 50.38s
96 0.00829 0.03616 0.22912 0.99432 50.37s
97 0.00909 0.03199 0.28410 0.99503 50.15s
98 0.01131 0.03966 0.28525 0.99342 50.13s
99 0.00904 0.03635 0.24867 0.99449 50.13s
100 0.01015 0.03425 0.29642 0.99369 50.13s
101 0.01003 0.03207 0.31276 0.99458 50.17s
102 0.00952 0.04026 0.23651 0.99243 50.12s
103 0.01126 0.04044 0.27844 0.99333 50.16s
104 0.00965 0.03784 0.25510 0.99369 50.11s
105 0.00926 0.03443 0.26889 0.99423 50.20s
Early stopping.
Best valid loss was 0.025711 at epoch 55.
Loaded parameters to layer 'conv1' (shape 128x4x3x3).
Loaded parameters to layer 'conv1' (shape 128).
Loaded parameters to layer 'conv2' (shape 96x128x3x3).
Loaded parameters to layer 'conv2' (shape 96).
Loaded parameters to layer 'conv3' (shape 96x96x3x3).
Loaded parameters to layer 'conv3' (shape 96).
Loaded parameters to layer 'conv4' (shape 96x96x3x3).
Loaded parameters to layer 'conv4' (shape 96).
Loaded parameters to layer 'hidden5' (shape 384x512).
Loaded parameters to layer 'hidden5' (shape 512).
Loaded parameters to layer 'output' (shape 512x2).
Loaded parameters to layer 'output' (shape 2).
In [13]:
# Evaluate the trained network on the held-out test patches.
# NOTE(review): presumably nolearn's NeuralNet.score, i.e. the mean accuracy
# of predict(X_test) against y_test — confirm.
test_accuracy = cnn.score(X_test, y_test)
In [14]:
# NOTE(review): the value shown below is exactly 13/14, which implies the test
# set has a multiple of 14 patches and may be very small. Check len(X_test)
# before reading much into the gap versus the ~0.99 validation accuracy
# reported during training.
test_accuracy
Out[14]:
0.9285714285714286
In [15]:
# Training vs. validation loss per epoch. The trailing ';' suppresses the
# '<module matplotlib.pyplot ...>' repr that plot_loss's return value would
# otherwise print as the cell output.
plot_loss(cnn);
Out[15]:
<module 'matplotlib.pyplot' from '/home/d/nolearn/local/lib/python2.7/site-packages/matplotlib/pyplot.pyc'>
In [12]:
# Persist the trained network so it can be reloaded without retraining.
# Pickling it apparently needs a very deep recursion limit (presumably because
# of the nested Theano graph), so the limit is raised only for the duration of
# the dump. The previous limit is restored afterwards: leaving a limit of 1e9
# in force means any later runaway recursion can overflow the C stack and kill
# the kernel instead of raising a RecursionError.
CNN_PICKLE_PATH = os.path.expanduser('~/Projects/gp/nets/RGBAPlusPlus.p')

cnn_pickle_dir = os.path.dirname(CNN_PICKLE_PATH)
if not os.path.isdir(cnn_pickle_dir):
    # open(..., 'wb') fails if the target directory does not exist yet.
    os.makedirs(cnn_pickle_dir)

previous_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(1000000000)
try:
    with open(CNN_PICKLE_PATH, 'wb') as f:
        # HIGHEST_PROTOCOL is what the former magic value -1 selected.
        pickle.dump(cnn, f, pickle.HIGHEST_PROTOCOL)
finally:
    sys.setrecursionlimit(previous_recursion_limit)
In [ ]:
In [ ]:
Content source: VCG/gp
Similar notebooks: