Mostly based on the Lasagne Recipes spatial transformer network example: https://github.com/Lasagne/Recipes/blob/master/examples/spatial_transformer_network.ipynb
In [1]:
# most of the code is kept in helpers, to make it easier to grep for and change
# than in an ipython notebook
%matplotlib inline
import helpers
import matplotlib.pyplot as plt
Couldn't import dot_parser, loading of dot files will not be possible.
Using gpu device 0: GeForce GTX 980 (CNMeM is enabled)
In [2]:
%%time
# load cluttered MNIST data (41MB)
in_train, in_valid, in_test = helpers.load_data()
Train samples: (50000, 1, 60, 60)
Validation samples: (10000, 1, 60, 60)
Test samples: (10000, 1, 60, 60)
CPU times: user 3.02 s, sys: 157 ms, total: 3.18 s
Wall time: 3.56 s
In [3]:
# Display a single raw training image (index 101, channel 0) to show the clutter.
sample_image = in_train["x"][101, 0]
fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(sample_image, cmap='gray', interpolation='none')
ax.set_title('Cluttered MNIST', fontsize=20)
ax.axis('off')
plt.show()
In [12]:
%%time
# loading a network
network = helpers.load_network(update_scale_factor=0.5)
print network.root_node
HyperparameterNode(name=u'with_updates', learning_rate=0.002, cost_function=<function categorical_crossentropy_i32 at 0x7fbeb81675f0>)
| AdamNode(name=u'adam')
| | TotalCostNode(name=u'cost')
| | | ReferenceNode(name=u'pred_ref', reference=u'model')
| | | InputNode(name=u'y', dtype=u'int32', shape=(None,))
| | HyperparameterNode(name=u'model', bn_update_moving_stats=True, pool_size=(2, 2), filter_size=(3, 3), inits=[<treeano.inits.HeUniformInit object at 0x7fbdf7e21910>], num_units=256, dropout_probability=0.5, num_filters=32)
| | | SequentialNode(name=u'seq')
| | | | InputNode(name=u'x', shape=(None, 1, 60, 60))
| | | | UpdateScaleNode(name=u'st_update_scale', update_scale_factor=0.5)
| | | | | AffineSpatialTransformerNode(name=u'st', output_shape=(20, 20))
| | | | | | HyperparameterNode(name=u'loc', filter_size=(5, 5), num_filters=20, pool_size=(2, 2))
| | | | | | | SequentialNode(name=u'loc_seq')
| | | | | | | | MaxPool2DDNNNode(name=u'loc_pool1')
| | | | | | | | Conv2DDNNNode(name=u'loc_conv1')
| | | | | | | | MaxPool2DDNNNode(name=u'loc_pool2')
| | | | | | | | NoScaleBatchNormalizationNode(name=u'loc_bn1')
| | | | | | | | ReLUNode(name=u'loc_relu1')
| | | | | | | | Conv2DDNNNode(name=u'loc_conv2')
| | | | | | | | NoScaleBatchNormalizationNode(name=u'loc_bn2')
| | | | | | | | ReLUNode(name=u'loc_relu2')
| | | | | | | | DenseNode(name=u'loc_fc1', num_units=50)
| | | | | | | | NoScaleBatchNormalizationNode(name=u'loc_bn3')
| | | | | | | | ReLUNode(name=u'loc_relu3')
| | | | | | | | DenseNode(name=u'loc_fc2', inits=[<treeano.inits.NormalWeightInit object at 0x7fbdf7e33e10>], num_units=6)
| | | | Conv2DNode(name=u'conv1')
| | | | MaxPool2DNode(name=u'mp1')
| | | | NoScaleBatchNormalizationNode(name=u'bn1')
| | | | ReLUNode(name=u'relu1')
| | | | Conv2DNode(name=u'conv2')
| | | | MaxPool2DNode(name=u'mp2')
| | | | NoScaleBatchNormalizationNode(name=u'bn2')
| | | | ReLUNode(name=u'relu2')
| | | | GaussianDropoutNode(name=u'do1')
| | | | DenseNode(name=u'fc1')
| | | | NoScaleBatchNormalizationNode(name=u'bn3')
| | | | ReLUNode(name=u'relu3')
| | | | DenseNode(name=u'fc2', num_units=10)
| | | | SoftmaxNode(name=u'pred')
CPU times: user 557 ms, sys: 16 ms, total: 573 ms
Wall time: 569 ms
In [13]:
%%time
# train the network
helpers.train_network(network, in_train, in_valid, max_iters=50)
build took 0.5095s
network_compile took 3.3815s
compile_function took 3.3852s
build took 0.0000s
network_compile took 13.1547s
compile_function took 13.1588s
Starting training...
Beginning evaluate_until
generating_data took 0.0000s
data_transfer took 0.0842s
data_free took 0.0001s
data_transfer took 0.0288s
data_free took 0.0001s
1: train_cost: 2.637 valid_cost: 1.471 valid_accuracy: 0.506
generating_data took 0.0000s
data_transfer took 0.0803s
data_free took 0.0001s
data_transfer took 0.0289s
data_free took 0.0001s
2: train_cost: 1.525 valid_cost: 0.913 valid_accuracy: 0.707
generating_data took 0.0000s
data_transfer took 0.0776s
data_free took 0.0001s
data_transfer took 0.0306s
data_free took 0.0002s
3: train_cost: 1.077 valid_cost: 0.648 valid_accuracy: 0.791
generating_data took 0.0000s
data_transfer took 0.0812s
data_free took 0.0001s
data_transfer took 0.0291s
data_free took 0.0001s
4: train_cost: 0.833 valid_cost: 0.513 valid_accuracy: 0.840
generating_data took 0.0000s
data_transfer took 0.0801s
data_free took 0.0001s
data_transfer took 0.0301s
data_free took 0.0001s
5: train_cost: 0.707 valid_cost: 0.424 valid_accuracy: 0.862
generating_data took 0.0000s
data_transfer took 0.0851s
data_free took 0.0001s
data_transfer took 0.0294s
data_free took 0.0001s
6: train_cost: 0.626 valid_cost: 0.405 valid_accuracy: 0.875
generating_data took 0.0000s
data_transfer took 0.0825s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
7: train_cost: 0.564 valid_cost: 0.374 valid_accuracy: 0.886
generating_data took 0.0000s
data_transfer took 0.0824s
data_free took 0.0001s
data_transfer took 0.0300s
data_free took 0.0001s
8: train_cost: 0.502 valid_cost: 0.321 valid_accuracy: 0.899
generating_data took 0.0000s
data_transfer took 0.0800s
data_free took 0.0002s
data_transfer took 0.0295s
data_free took 0.0001s
9: train_cost: 0.493 valid_cost: 0.303 valid_accuracy: 0.904
generating_data took 0.0000s
data_transfer took 0.0818s
data_free took 0.0001s
data_transfer took 0.0292s
data_free took 0.0001s
10: train_cost: 0.445 valid_cost: 0.266 valid_accuracy: 0.915
generating_data took 0.0000s
data_transfer took 0.0801s
data_free took 0.0001s
data_transfer took 0.0303s
data_free took 0.0001s
11: train_cost: 0.423 valid_cost: 0.281 valid_accuracy: 0.915
generating_data took 0.0000s
data_transfer took 0.0854s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
12: train_cost: 0.395 valid_cost: 0.257 valid_accuracy: 0.917
generating_data took 0.0000s
data_transfer took 0.0945s
data_free took 0.0001s
data_transfer took 0.0297s
data_free took 0.0001s
13: train_cost: 0.381 valid_cost: 0.252 valid_accuracy: 0.923
generating_data took 0.0000s
data_transfer took 0.0819s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
14: train_cost: 0.367 valid_cost: 0.230 valid_accuracy: 0.927
generating_data took 0.0000s
data_transfer took 0.0850s
data_free took 0.0001s
data_transfer took 0.0298s
data_free took 0.0001s
15: train_cost: 0.359 valid_cost: 0.225 valid_accuracy: 0.929
generating_data took 0.0000s
data_transfer took 0.0822s
data_free took 0.0001s
data_transfer took 0.0503s
data_free took 0.0001s
16: train_cost: 0.356 valid_cost: 0.225 valid_accuracy: 0.927
generating_data took 0.0000s
data_transfer took 0.1611s
data_free took 0.0001s
data_transfer took 0.0294s
data_free took 0.0002s
17: train_cost: 0.340 valid_cost: 0.208 valid_accuracy: 0.934
generating_data took 0.0000s
data_transfer took 0.0837s
data_free took 0.0002s
data_transfer took 0.0298s
data_free took 0.0001s
18: train_cost: 0.339 valid_cost: 0.204 valid_accuracy: 0.936
generating_data took 0.0000s
data_transfer took 0.0814s
data_free took 0.0002s
data_transfer took 0.0301s
data_free took 0.0001s
19: train_cost: 0.324 valid_cost: 0.210 valid_accuracy: 0.934
generating_data took 0.0000s
data_transfer took 0.0841s
data_free took 0.0002s
data_transfer took 0.0303s
data_free took 0.0001s
20: train_cost: 0.321 valid_cost: 0.207 valid_accuracy: 0.934
generating_data took 0.0000s
data_transfer took 0.0829s
data_free took 0.0001s
data_transfer took 0.0468s
data_free took 0.0001s
21: train_cost: 0.304 valid_cost: 0.193 valid_accuracy: 0.943
generating_data took 0.0000s
data_transfer took 0.1643s
data_free took 0.0002s
data_transfer took 0.0303s
data_free took 0.0001s
22: train_cost: 0.303 valid_cost: 0.205 valid_accuracy: 0.938
generating_data took 0.0000s
data_transfer took 0.1854s
data_free took 0.0002s
data_transfer took 0.0298s
data_free took 0.0001s
23: train_cost: 0.321 valid_cost: 0.210 valid_accuracy: 0.934
generating_data took 0.0000s
data_transfer took 0.0835s
data_free took 0.0013s
data_transfer took 0.0318s
data_free took 0.0001s
24: train_cost: 0.296 valid_cost: 0.188 valid_accuracy: 0.941
generating_data took 0.0000s
data_transfer took 0.0852s
data_free took 0.0002s
data_transfer took 0.0296s
data_free took 0.0001s
25: train_cost: 0.298 valid_cost: 0.190 valid_accuracy: 0.940
generating_data took 0.0000s
data_transfer took 0.0830s
data_free took 0.0001s
data_transfer took 0.0299s
data_free took 0.0001s
26: train_cost: 0.283 valid_cost: 0.172 valid_accuracy: 0.946
generating_data took 0.0000s
data_transfer took 0.0831s
data_free took 0.0002s
data_transfer took 0.0303s
data_free took 0.0001s
27: train_cost: 0.290 valid_cost: 0.183 valid_accuracy: 0.943
generating_data took 0.0000s
data_transfer took 0.0850s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
28: train_cost: 0.266 valid_cost: 0.167 valid_accuracy: 0.949
generating_data took 0.0000s
data_transfer took 0.0827s
data_free took 0.0001s
data_transfer took 0.0299s
data_free took 0.0001s
29: train_cost: 0.264 valid_cost: 0.176 valid_accuracy: 0.946
generating_data took 0.0000s
data_transfer took 0.0829s
data_free took 0.0002s
data_transfer took 0.0308s
data_free took 0.0001s
30: train_cost: 0.275 valid_cost: 0.170 valid_accuracy: 0.948
generating_data took 0.0000s
data_transfer took 0.0851s
data_free took 0.0001s
data_transfer took 0.0295s
data_free took 0.0001s
31: train_cost: 0.273 valid_cost: 0.189 valid_accuracy: 0.942
generating_data took 0.0000s
data_transfer took 0.0836s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
32: train_cost: 0.260 valid_cost: 0.160 valid_accuracy: 0.950
generating_data took 0.0000s
data_transfer took 0.0817s
data_free took 0.0001s
data_transfer took 0.0298s
data_free took 0.0001s
33: train_cost: 0.248 valid_cost: 0.155 valid_accuracy: 0.952
generating_data took 0.0000s
data_transfer took 0.0830s
data_free took 0.0002s
data_transfer took 0.0301s
data_free took 0.0001s
34: train_cost: 0.248 valid_cost: 0.161 valid_accuracy: 0.952
generating_data took 0.0000s
data_transfer took 0.0821s
data_free took 0.0001s
data_transfer took 0.0298s
data_free took 0.0001s
35: train_cost: 0.250 valid_cost: 0.169 valid_accuracy: 0.947
generating_data took 0.0000s
data_transfer took 0.0810s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
36: train_cost: 0.242 valid_cost: 0.162 valid_accuracy: 0.949
generating_data took 0.0000s
data_transfer took 0.0825s
data_free took 0.0001s
data_transfer took 0.0298s
data_free took 0.0001s
37: train_cost: 0.258 valid_cost: 0.161 valid_accuracy: 0.951
generating_data took 0.0000s
data_transfer took 0.0806s
data_free took 0.0001s
data_transfer took 0.0298s
data_free took 0.0001s
38: train_cost: 0.248 valid_cost: 0.141 valid_accuracy: 0.955
generating_data took 0.0000s
data_transfer took 0.0843s
data_free took 0.0001s
data_transfer took 0.0304s
data_free took 0.0001s
39: train_cost: 0.234 valid_cost: 0.177 valid_accuracy: 0.944
generating_data took 0.0000s
data_transfer took 0.0859s
data_free took 0.0001s
data_transfer took 0.0294s
data_free took 0.0001s
40: train_cost: 0.246 valid_cost: 0.150 valid_accuracy: 0.952
generating_data took 0.0000s
data_transfer took 0.0829s
data_free took 0.0001s
data_transfer took 0.0302s
data_free took 0.0002s
41: train_cost: 0.249 valid_cost: 0.151 valid_accuracy: 0.951
generating_data took 0.0000s
data_transfer took 0.0820s
data_free took 0.0001s
data_transfer took 0.0295s
data_free took 0.0001s
42: train_cost: 0.238 valid_cost: 0.154 valid_accuracy: 0.950
generating_data took 0.0000s
data_transfer took 0.0810s
data_free took 0.0001s
data_transfer took 0.0295s
data_free took 0.0001s
43: train_cost: 0.239 valid_cost: 0.150 valid_accuracy: 0.954
generating_data took 0.0000s
data_transfer took 0.0826s
data_free took 0.0001s
data_transfer took 0.0294s
data_free took 0.0001s
44: train_cost: 0.223 valid_cost: 0.148 valid_accuracy: 0.953
generating_data took 0.0000s
data_transfer took 0.0826s
data_free took 0.0001s
data_transfer took 0.0293s
data_free took 0.0001s
45: train_cost: 0.230 valid_cost: 0.149 valid_accuracy: 0.953
generating_data took 0.0000s
data_transfer took 0.0806s
data_free took 0.0001s
data_transfer took 0.0299s
data_free took 0.0001s
46: train_cost: 0.218 valid_cost: 0.147 valid_accuracy: 0.953
generating_data took 0.0000s
data_transfer took 0.0829s
data_free took 0.0003s
data_transfer took 0.0291s
data_free took 0.0001s
47: train_cost: 0.231 valid_cost: 0.157 valid_accuracy: 0.950
generating_data took 0.0000s
data_transfer took 0.0797s
data_free took 0.0001s
data_transfer took 0.0296s
data_free took 0.0001s
48: train_cost: 0.227 valid_cost: 0.132 valid_accuracy: 0.959
generating_data took 0.0000s
data_transfer took 0.0830s
data_free took 0.0001s
data_transfer took 0.0294s
data_free took 0.0001s
49: train_cost: 0.219 valid_cost: 0.147 valid_accuracy: 0.954
generating_data took 0.0000s
data_transfer took 0.0805s
data_free took 0.0001s
data_transfer took 0.0301s
data_free took 0.0001s
50: train_cost: 0.211 valid_cost: 0.131 valid_accuracy: 0.958
generating_data took 0.0000s
CPU times: user 2min 52s, sys: 30.5 s, total: 3min 22s
Wall time: 3min 22s
In [14]:
%%time
# compile evaluation function for the trained network; the result is called
# on in_test by the results/plotting cell. %%time reports the compile cost.
test_fn = helpers.test_fn(network)
build took 1.3061s
network_compile took 1.4400s
compile_function took 1.4417s
CPU times: user 2.62 s, sys: 136 ms, total: 2.76 s
Wall time: 2.75 s
In [15]:
# results: show test images next to the spatial transformer's output,
# one row per example (original on the left, transformed on the right).
n_examples = 3  # number of test images to display

# Evaluate the test set ONCE. The previous version called test_fn(in_test)
# inside the loop, re-running the full evaluation for every displayed row.
# NOTE(review): assumes "transformed" is an ndarray indexed as
# (example, channel, h, w), matching the [i, 0] indexing used before -- confirm.
transformed = test_fn(in_test)["transformed"]
originals = in_test["x"]

# squeeze=False keeps `axes` 2-D even when n_examples == 1.
# The height scales with the row count; 3 rows gives the original 7x14 figure.
fig, axes = plt.subplots(n_examples, 2,
                         figsize=(7, 14 * n_examples / 3.0),
                         squeeze=False)
for i in range(n_examples):
    left, right = axes[i]
    left.imshow(originals[i, 0], cmap='gray', interpolation='none')
    left.axis('off')
    right.imshow(transformed[i, 0], cmap='gray', interpolation='none')
    right.axis('off')
    if i == 0:
        # Titles derive from the data so they stay correct if sizes change
        # (currently 'Original 60x60' and 'Transformed 20x20').
        left.set_title('Original %dx%d' % tuple(originals.shape[-2:]),
                       fontsize=20)
        right.set_title('Transformed %dx%d' % tuple(transformed.shape[-2:]),
                        fontsize=20)
plt.tight_layout()
data_transfer took 0.0235s
data_free took 0.0001s
data_transfer took 0.0199s
data_free took 0.0001s
data_transfer took 0.0188s
data_free took 0.0001s
In [ ]:
Content source: nsauder/treeano
Similar notebooks: