In [1]:
import sys
sys.path.insert(0, '../')
In [2]:
import time, pickle
from sklearn.externals import joblib
import numpy as np
np.set_printoptions(precision=3, linewidth=200, suppress=True)
from sklearn.metrics import classification_report, accuracy_score
In [3]:
from library.datasets.cifar10 import CIFAR10
from library.plot_tools import plot
from library.utils import file_utils
In [6]:
# Identifiers that select which logged experiment this notebook evaluates;
# they are combined into the model path in the next cell.
dataset, file_no, exp_no = 'cifar10', 203, 1

# Running total of elapsed wall-clock seconds, incremented by the timed steps.
total_time = 0
In [7]:
# Log directory for the selected experiment, built from `dataset`, `file_no`
# and `exp_no`; both numbers are zero-padded to three digits
# (e.g. 203 -> '203', 1 -> '001').
model_folder = '../logs/{}/{}_tfl_cnn/exp_no_{}/'.format(
    dataset, str(file_no).zfill(3), str(exp_no).zfill(3))

# Path of the pickled classifier inside that directory.
# Fixed: this line previously referenced the misspelled name `nodel_folder`,
# which raised NameError before any model could be loaded.
model_name = model_folder + 'residual_net_classifier.pkl'
In [10]:
# Options for the CIFAR10 loader constructed in the next cell.
# No train/validation split is requested.
train_val_split_data = None
# NOTE(review): the meaning of 0.0 (all images? a fraction?) is defined by
# library.datasets.cifar10 — confirm there.
num_images_required = 0.0
# Label encoding and preprocessing switches.
# NOTE(review): `transform` is not read anywhere in the visible cells.
one_hot = transform = True
transform_method = 'StandardScaler'
In [11]:
# Step 1: load the CIFAR-10 test split only (train=False, test=True) and time it.
#
# The label encoding and preprocessing method are taken from the configuration
# cell (`one_hot`, `transform_method`). Previously this cell re-assigned
# `one_hot = True` and hard-coded preprocess='StandardScaler', so editing the
# configuration had no effect on what was actually loaded.
start = time.time()
cifar10 = CIFAR10(
    one_hot_encode=one_hot,
    num_images=num_images_required,
    preprocess=transform_method,
    train_validate_split=train_val_split_data,
    endian='little',
)
# NOTE(review): the data directory is relative to the notebook's working
# directory ('./'), whereas the model path is built from '../' — confirm both
# resolve correctly from where the notebook is run.
cifar10.load_data(train=False, test=True, data_directory='./datasets/cifar10/')
end = time.time()

# time.time() returns seconds; the message reports milliseconds.
print('[ Step 1] Loaded CIFAR 10 Dataset in %.4f ms' % ((end - start) * 1000))
# Accumulates in seconds. Re-running this cell adds to the total again.
total_time += (end - start)
In [ ]:
# Load the object stored at `model_name`; the next cell calls `.predict` on it.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model files that come from a trusted source.
# NOTE(review): `joblib` is imported above via `sklearn.externals`, which was
# removed in scikit-learn 0.23; on newer versions use `import joblib` instead.
model = joblib.load(model_name)
In [ ]:
# Run the loaded model on the test inputs. Later cells index the result per
# image and cast each entry to int to look up a class name.
# NOTE(review): this reads `cifar10.test.images`, while every other cell reads
# `cifar10.test.data` — confirm both attributes exist and which one the model
# expects.
prediction_numbers = model.predict(cifar10.test.images)
In [17]:
# Map each predicted class index to its class name, one entry per test image.
num_test_images = cifar10.test.data.shape[0]
prediction_classes = [
    cifar10.classes[int(prediction_numbers[i])]
    for i in range(num_test_images)
]
In [18]:
# Overall accuracy of the predictions against the true test labels.
test_accuracy = accuracy_score(
    y_true=cifar10.test.class_labels,
    y_pred=prediction_numbers,
)
print('Accuracy of the classifier on test dataset: %.4f' % (test_accuracy,))
In [19]:
# Show the first 50 test images in a 5 x 10 grid, passing both the true class
# names and the predicted class names for each image.
first_fifty = slice(None, 50)
cifar10.plot_images(
    cifar10.test.data[first_fifty],
    cifar10.test.class_names[first_fifty],
    cls_pred=prediction_classes[first_fifty],
    nrows=5,
    ncols=10,
    fig_size=(20, 50),
    fontsize=30,
    convert=True,
)
Out[19]:
In [20]:
# Plot a confusion matrix of the true test labels against the predictions,
# labelled with the dataset's class names.
# NOTE(review): normalize=True presumably scales counts to proportions —
# confirm the exact normalisation in library.plot_tools.
plot.plot_confusion_matrix(cifar10.test.class_labels, prediction_numbers, classes=cifar10.classes,
normalize=True, title='Confusion matrix')
In [21]:
# Text report from scikit-learn's classification_report, with rows labelled
# by the dataset's class names.
print('Detailed classification report')
report = classification_report(
    y_true=cifar10.test.class_labels,
    y_pred=prediction_numbers,
    target_names=cifar10.classes,
)
print(report)
In [22]:
# Close a session only if one actually exists in the notebook namespace.
# No cell in this notebook creates `sess` (the model is loaded with joblib),
# so the previous unconditional `sess.close()` raised NameError on a fresh
# kernel. The lookup keeps the cleanup for kernels where a session was
# created by other means.
active_session = globals().get('sess')
if active_session is not None:
    active_session.close()
In [ ]: