In this notebook, we check how a network trained on ordinary data copes with augmented data, and what happens when it is trained on augmented data instead.
You can see how the neural network class is implemented in this file.
In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqn
%matplotlib inline
sys.path.append('../../..')
sys.path.append('../../utils')
import utils
from secondbatch import MnistBatch
from simple_conv_model import ConvModel
from batchflow import V, B
from batchflow.opensets import MNIST
Create the MNIST dataset with MnistBatch as its batch class
In [2]:
mnistset = MNIST(batch_class=MnistBatch)
We use the already familiar construction to create the pipelines. These pipelines train the network on plain MNIST images, without shift.
In [3]:
normal_train_ppl = (
    mnistset.train.p
    .init_model('dynamic',
                ConvModel,
                'conv',
                config={'inputs': dict(images={'shape': (28, 28, 1)},
                                       labels={'classes': 10,
                                               'transform': 'ohe',
                                               'name': 'targets'}),
                        'loss': 'ce',
                        'optimizer': 'Adam',
                        'input_block/inputs': 'images',
                        'head/units': 10,
                        'output': dict(ops=['labels',
                                            'proba',
                                            'accuracy'])})
    .train_model('conv',
                 feed_dict={'images': B('images'),
                            'labels': B('labels')})
)
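To make the `'transform': 'ohe'` and `'loss': 'ce'` settings more concrete, here is a minimal NumPy sketch of one-hot encoding and cross-entropy. The helpers `one_hot` and `cross_entropy` are hypothetical names written for this illustration; they are not part of batchflow, and the library's own implementation may differ.

```python
import numpy as np

def one_hot(labels, n_classes=10):
    """Turn integer labels of shape (n,) into a one-hot matrix of shape (n, n_classes)."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, n_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

def cross_entropy(proba, targets, eps=1e-12):
    """Mean cross-entropy between predicted probabilities and one-hot targets."""
    proba = np.clip(proba, eps, 1.0)  # avoid log(0)
    return float(-np.mean(np.sum(targets * np.log(proba), axis=1)))

# Three example labels and a model that predicts a uniform distribution
targets = one_hot([3, 0, 9])
uniform = np.full((3, 10), 0.1)
loss = cross_entropy(uniform, targets)  # equals log(10) for a uniform prediction
```

A model that has learned nothing about ten balanced classes sits near this uniform-prediction loss, which gives a rough reference point when reading training curves.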
In [4]:
normal_test_ppl = (
    mnistset.test.p
    .import_model('conv', normal_train_ppl)
    .init_variable('test_accuracy', init_on_each_run=int)
    .predict_model('conv',
                   fetches='output_accuracy',
                   feed_dict={'images': B('images'),
                              'labels': B('labels')},
                   save_to=V('test_accuracy'),
                   mode='w')
)
Train the model with the next_batch method
In [5]:
batch_size = 400
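for i in tqn(range(600)):
    normal_train_ppl.next_batch(batch_size, n_epochs=None)
    normal_test_ppl.next_batch(batch_size, n_epochs=None)
Get the variable from the pipeline and print the accuracy on data without shift. Because the value is saved with mode='w', it reflects the most recent test batch rather than the whole test set.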
In [6]:
acc = normal_test_ppl.get_variable('test_accuracy')
print('Accuracy on normal data: {:.2%}'.format(acc))
Now check how the accuracy changes when the first model is tested on shifted data
In [7]:
shift_test_ppl = (
    mnistset.test.p
    .import_model('conv', normal_train_ppl)
    .shift_flattened_pic()
    .init_variable('predict', init_on_each_run=int)
    .predict_model('conv',
                   fetches='output_accuracy',
                   feed_dict={'images': B('images'),
                              'labels': B('labels')},
                   save_to=V('predict'),
                   mode='w')
    .run(batch_size, n_epochs=1)
)
In [8]:
print('Accuracy with shift: {:.2%}'.format(shift_test_ppl.get_variable('predict')))
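The `shift_flattened_pic` action is defined in `secondbatch.py`, which is not shown here. As a rough idea of what such an augmentation can look like, the sketch below shifts a flattened 28x28 image by a random offset. The function `random_shift`, its `max_shift` parameter, and the cyclic `np.roll` shift are all assumptions made for illustration; the real action may pad with zeros or use a different shift range.

```python
import numpy as np

def random_shift(flat_image, max_shift=8, rng=None):
    """Cyclically shift a flattened 28x28 image by a random offset (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    image = np.asarray(flat_image).reshape(28, 28)
    # Independent vertical and horizontal offsets in [-max_shift, max_shift]
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    shifted = np.roll(image, shift=(int(dy), int(dx)), axis=(0, 1))
    return shifted.reshape(-1)

# A blank image with a single bright pixel in the centre
rng = np.random.default_rng(0)
image = np.zeros(784, dtype=np.float32)
image[14 * 28 + 14] = 1.0
shifted = random_shift(image, rng=rng)
```

A shift like this keeps every pixel value but moves the digit away from the centre, which is exactly the kind of change a model trained only on centred digits has not seen.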
For the model to handle augmented data, we will train it on such data
In [9]:
shift_train_ppl = (
    mnistset.train.p
    .shift_flattened_pic()
    .init_model('dynamic',
                ConvModel,
                'conv',
                config={'inputs': dict(images={'shape': (28, 28, 1)},
                                       labels={'classes': 10,
                                               'transform': 'ohe',
                                               'name': 'targets'}),
                        'loss': 'ce',
                        'optimizer': 'Adam',
                        'input_block/inputs': 'images',
                        'head/units': 10,
                        'output': dict(ops=['labels',
                                            'proba',
                                            'accuracy'])})
    .train_model('conv',
                 feed_dict={'images': B('images'),
                            'labels': B('labels')})
)
In [10]:
for i in tqn(range(600)):
    shift_train_ppl.next_batch(batch_size, n_epochs=None)
Now check how the accuracy on shifted data has changed
In [11]:
shift_test_ppl = (
    mnistset.test.p
    .import_model('conv', shift_train_ppl)
    .shift_flattened_pic()
    .init_variable('acc', init_on_each_run=list)
    .init_variable('img', init_on_each_run=list)
    .init_variable('predict', init_on_each_run=list)
    .predict_model('conv',
                   fetches=['output_accuracy', 'inputs', 'output_proba'],
                   feed_dict={'images': B('images'),
                              'labels': B('labels')},
                   save_to=[V('acc'), V('img'), V('predict')],
                   mode='a')
    .run(1, n_epochs=1)
)
In [12]:
print('Accuracy with shift: {:.2%}'.format(np.mean(shift_test_ppl.get_variable('acc'))))
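Here the pipeline runs with a batch size of 1, so every stored accuracy is 0 or 1 and a plain mean gives the accuracy over the test set. With larger batches of unequal size, the per-batch values would need to be weighted by batch size. The sketch below shows both cases; `dataset_accuracy` is a hypothetical helper and the numbers are made up for illustration.

```python
import numpy as np

def dataset_accuracy(batch_accuracies, batch_sizes):
    """Combine per-batch accuracies into one accuracy, weighting by batch size."""
    return float(np.average(batch_accuracies, weights=batch_sizes))

# Batch size 1: each entry is 0 or 1, so this is the same as a plain mean
per_item = [1.0, 0.0, 1.0, 1.0]
acc_items = dataset_accuracy(per_item, np.ones(len(per_item)))

# Unequal batches: the larger batch contributes proportionally more
acc_batches = dataset_accuracy([0.9, 0.5], [400, 100])
```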
It is much better than before.
It is interesting to see which digits the model gets wrong.
In [13]:
acc = shift_test_ppl.get_variable('acc')
img = shift_test_ppl.get_variable('img')
predict = shift_test_ppl.get_variable('predict')
In [14]:
_, ax = plt.subplots(3, 4, figsize=(16, 16))
ax = ax.reshape(-1)
# Indices of test images the model got wrong (per-item accuracy is 0)
wrong = np.where(np.array(acc) == 0)[0]
for i in range(12):
    index = wrong[i]
    ax[i].imshow(img[index]['images'].reshape(-1, 28))
    ax[i].set_xlabel('Predict: {}'.format(np.argmax(predict[index][0])), fontsize=18)
    ax[i].grid()
In most cases, the model is wrong on examples where the digit is heavily shifted; some of them are hard to recognize even by eye.
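The plotting cell above assumes that at least 12 test images were misclassified; with fewer errors, indexing would raise an `IndexError`. A slightly more defensive way to pick the examples is sketched below. The helper `misclassified_indices` is a hypothetical name, and the short `acc_demo` list stands in for the real `acc` variable.

```python
import numpy as np

def misclassified_indices(acc, limit=12):
    """Return up to `limit` indices whose per-item accuracy is zero."""
    wrong = np.flatnonzero(np.asarray(acc) == 0)
    return wrong[:limit]

# Toy per-item accuracies: items 1, 3 and 4 are wrong
acc_demo = [1.0, 0.0, 1.0, 0.0, 0.0]
wrong_demo = misclassified_indices(acc_demo, limit=2)

# With no errors at all the result is simply empty
none_wrong = misclassified_indices([1.0, 1.0])
```

Looping over the returned indices instead of a fixed `range(12)` lets the same plotting code work no matter how many errors the model makes.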
If you have not finished our tutorial yet, now is a good time to continue! Read and try the other experiments: