MNIST Analysis with Distributed Keras

Joeri Hermans (Technical Student, IT-DB-SAS, CERN)
Department of Knowledge Engineering
Maastricht University, The Netherlands


In [1]:
!(date +%d\ %B\ %G)


18 January 2017

In this notebook we will show you how to process the MNIST dataset using Distributed Keras. As in the workflow notebook, we will guide you through the complete machine learning pipeline.

Preparation

To get started, we first load all the required imports. Please make sure that dist-keras and seaborn are installed. Furthermore, we assume that you have access to an Apache Spark installation.

Before you start this notebook, please make sure you have run the "MNIST preprocessing" notebook, since this notebook uses the manually "enlarged" training set that it produces.


In [2]:
%matplotlib inline

import numpy as np

from keras.optimizers import *
from keras.models import Sequential
from keras.layers.core import *
from keras.layers.convolutional import *

from pyspark import SparkContext
from pyspark import SparkConf

from matplotlib import pyplot as plt

from pyspark import StorageLevel

from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

from distkeras.trainers import *
from distkeras.predictors import *
from distkeras.transformers import *
from distkeras.evaluators import *
from distkeras.utils import *


Using TensorFlow backend.

In the following cell, adapt the parameters to fit your personal requirements.


In [3]:
# Modify these variables according to your needs.
application_name = "Distributed Keras MNIST Analysis"
using_spark_2 = False
local = False
path = "mnist.parquet"
if local:
    # Tell master to use local resources.
    master = "local[*]"
    num_processes = 3
    num_executors = 1
else:
    # Tell master to use YARN.
    master = "yarn-client"
    num_executors = 30
    num_processes = 1

In [4]:
# This variable is derived from the number of cores and executors, and will be used to assign the number of model trainers.
num_workers = num_executors * num_processes

print("Number of desired executors: " + str(num_executors))
print("Number of desired processes / executor: " + str(num_processes))
print("Total number of workers: " + str(num_workers))


Number of desired executors: 30
Number of desired processes / executor: 1
Total number of workers: 30

In [5]:
conf = SparkConf()
conf.set("spark.app.name", application_name)
conf.set("spark.master", master)
conf.set("spark.executor.cores", str(num_processes))
conf.set("spark.executor.instances", str(num_executors))
conf.set("spark.locality.wait", "0")
conf.set("spark.executor.memory", "5g")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

# Check if the user is running Spark 2.0+
if using_spark_2:
    from pyspark.sql import SparkSession
    sc = SparkSession.builder.config(conf=conf) \
            .appName(application_name) \
            .getOrCreate()
else:
    # Create the Spark context.
    sc = SparkContext(conf=conf)
    # Add the missing imports
    from pyspark import SQLContext
    sqlContext = SQLContext(sc)

In [6]:
# Check if we are using Spark 2.0
if using_spark_2:
    reader = sc
else:
    reader = sqlContext
# Read the training and test set.
training_set = reader.read.parquet('data/mnist_train_big.parquet') \
                     .select("features_normalized_dense", "label_encoded", "label")
test_set = reader.read.parquet('data/mnist_test_preprocessed.parquet') \
                 .select("features_normalized_dense", "label_encoded", "label")

In [7]:
# Print the schema of the dataset.
training_set.printSchema()


root
 |-- features_normalized_dense: vector (nullable = true)
 |-- label_encoded: vector (nullable = true)
 |-- label: long (nullable = true)
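The schema shows that every row carries both the integer `label` and a vector `label_encoded`. The encoding is produced by the preprocessing notebook, which is not shown here; assuming it is a standard one-hot encoding over the 10 digit classes, the relation between the two columns can be sketched in NumPy as follows (the helper `one_hot` is hypothetical and not part of dist-keras).

```python
import numpy as np

def one_hot(label, num_classes=10):
    """Hypothetical helper: map an integer class label to a one-hot vector.

    Assumes `label_encoded` in the dataset follows this convention.
    """
    vector = np.zeros(num_classes, dtype=np.float32)
    vector[label] = 1.0
    return vector

encoded = one_hot(3)
print(encoded)                   # 1.0 at index 3, zeros elsewhere
print(int(np.argmax(encoded)))   # recovers the original label: 3
```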

Model Development

Multilayer Perceptron


In [8]:
mlp = Sequential()
mlp.add(Dense(1000, input_shape=(784,)))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(200))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(10))
mlp.add(Activation('softmax'))

In [9]:
mlp.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
dense_1 (Dense)                  (None, 1000)          785000      dense_input_1[0][0]              
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 1000)          0           dense_1[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 1000)          0           activation_1[0][0]               
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 200)           200200      dropout_1[0][0]                  
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 200)           0           dense_2[0][0]                    
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 200)           0           activation_2[0][0]               
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 10)            2010        dropout_2[0][0]                  
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 10)            0           dense_3[0][0]                    
====================================================================================================
Total params: 987210
____________________________________________________________________________________________________
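The parameter counts in the summary follow directly from the layer sizes: a fully connected layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, while the activation and dropout layers have no trainable parameters. A small sketch that reproduces the numbers above:

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out

# Layer sizes of the MLP defined above: 784 -> 1000 -> 200 -> 10.
sizes = [784, 1000, 200, 10]
per_layer = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]

print(per_layer)       # [785000, 200200, 2010]
print(sum(per_layer))  # 987210
```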

In [10]:
optimizer_mlp = 'adam'
loss_mlp = 'categorical_crossentropy'
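With a softmax output layer and one-hot targets, the `categorical_crossentropy` loss is the negative log-probability that the model assigns to the true class, averaged over the batch. The NumPy sketch below illustrates that formula; it is not Keras's own code, although, like Keras, it clips predicted probabilities away from 0 (here with an assumed epsilon of `1e-7`) before taking the logarithm.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row maximum keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum_k y_true[k] * log(y_pred[k])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# Two examples over 10 classes, with one-hot targets.
logits = np.zeros((2, 10))
logits[0, 3] = 5.0      # first example: confident and correct (class 3)
y_true = np.zeros((2, 10))
y_true[0, 3] = 1.0
y_true[1, 7] = 1.0      # second example: the prediction is uniform over 10 classes

probs = softmax(logits)
loss = categorical_crossentropy(y_true, probs)
print(loss)
```

The uniform prediction contributes `log(10)` to the sum, while the confident correct prediction contributes almost nothing, which is why the loss pushes probability mass onto the true class.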

Training

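Prepare the training and test sets: repartition both so that every worker receives one partition, and cache them in memory, since they will be read repeatedly during training and evaluation.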


In [11]:
training_set = training_set.repartition(num_workers)
test_set = test_set.repartition(num_workers)
training_set.cache()
test_set.cache()
print("Number of training instances: " + str(training_set.count()))
print("Number of testing instances: " + str(test_set.count()))


Number of training instances: 6060000
Number of testing instances: 10000
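A back-of-the-envelope estimate of the work per worker, combining these counts with the ADAG settings used in the training cell (`batch_size=4`, `communication_window=5`, `num_epoch=1`). It assumes that the partitions produced by `repartition(num_workers)` are equally sized (in practice they are only approximately equal) and that every worker makes exactly one pass over its own partition per epoch.

```python
num_instances = 6060000     # training-set size reported above
num_workers = 30            # 30 executors x 1 process each
batch_size = 4              # ADAG settings used in the training cell
communication_window = 5
num_epoch = 1

rows_per_worker = num_instances // num_workers
batches_per_worker = num_epoch * rows_per_worker // batch_size
commits_per_worker = batches_per_worker // communication_window

print(rows_per_worker)                   # 202000
print(batches_per_worker)                # 50500
print(commits_per_worker)                # 10100
print(commits_per_worker * num_workers)  # 303000 commits in total
```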

Evaluation

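We define a utility function that computes the classification accuracy of a Keras model on a DataFrame: it predicts an output vector for every instance, converts that vector into a predicted class index, and compares the index with the label.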


In [12]:
def evaluate_accuracy(model, test_set, features="features_normalized_dense"):
    evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
    predictor = ModelPredictor(keras_model=model, features_col=features)
    transformer = LabelIndexTransformer(output_dim=10)
    test_set = test_set.select(features, "label")
    test_set = predictor.predict(test_set)
    test_set = transformer.transform(test_set)
    score = evaluator.evaluate(test_set)
    
    return score
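Conceptually, the pipeline in `evaluate_accuracy` reduces each predicted probability vector to the index of its largest entry and counts how often that index equals the true label. The same computation on plain NumPy arrays (an illustration of the idea, not the dist-keras implementation; the toy probabilities are made up):

```python
import numpy as np

def accuracy_from_probabilities(probabilities, labels):
    """Fraction of rows whose most probable class equals the true label."""
    predicted = np.argmax(probabilities, axis=1)   # class index per row
    return float(np.mean(predicted == np.asarray(labels)))

probabilities = np.array([
    [0.1, 0.7, 0.2],   # predicts class 1
    [0.8, 0.1, 0.1],   # predicts class 0
    [0.3, 0.3, 0.4],   # predicts class 2
    [0.2, 0.5, 0.3],   # predicts class 1
])
labels = [1, 0, 1, 1]

print(accuracy_from_probabilities(probabilities, labels))  # 0.75
```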

ADAG


In [ ]:
trainer = ADAG(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp, num_workers=num_workers,
               batch_size=4, communication_window=5, num_epoch=1,
               features_col="features_normalized_dense", label_col="label_encoded")
# Train the model with ADAG on the repartitioned training set.
trained_model = trainer.train(training_set)
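ADAG (accumulated gradient normalization) lets each worker take several local mini-batch steps before it communicates. After `communication_window` steps, the worker sends the accumulated change of its weights, divided by the window length, to the parameter server, which adds it to the central weights. The sketch below simulates that update rule sequentially with NumPy on a toy quadratic loss. The loss, the learning rate, and the way staleness is modelled (all workers pull the same central weights before any of them commits) are assumptions chosen for illustration; this is not the dist-keras code path.

```python
import numpy as np

def local_sgd(theta, target, lr, steps):
    """Plain gradient descent on the toy loss 0.5 * ||theta - target||^2."""
    for _ in range(steps):
        gradient = theta - target
        theta = theta - lr * gradient
    return theta

def adag_simulation(num_workers=4, communication_window=5, lr=0.1, rounds=50):
    target = np.array([1.0, -2.0, 3.0])   # optimum of the toy loss
    center = np.zeros_like(target)        # central (parameter-server) weights
    for _ in range(rounds):
        # Every worker pulls the same center, so later commits in a round are stale.
        pulled = [center.copy() for _ in range(num_workers)]
        for start in pulled:
            local = local_sgd(start, target, lr, communication_window)
            # Accumulated update, normalized by the communication window.
            delta = (local - start) / communication_window
            center = center + delta       # the parameter server applies the commit
    return center, target

center, target = adag_simulation()
print(np.round(center, 4))
```

Dividing by the window keeps each commit roughly the size of a single local step, so several stale commits arriving in a row move the central weights less aggressively than raw accumulated updates would.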

In [19]:
# View the weights of the trained model.
trained_model.get_weights()


Out[19]:
[array([[-0.02490237, -0.01861665,  0.03102627, ...,  0.01722135,
          0.02223415, -0.04933412],
        [-0.02634868,  0.03564246, -0.05392314, ..., -0.02999102,
         -0.01270337, -0.03888189],
        [ 0.00727941,  0.04553502, -0.01856072, ...,  0.0319587 ,
         -0.00354035, -0.03581727],
        ..., 
        [-0.03245988, -0.01220334,  0.019447  , ...,  0.05723321,
         -0.05618715, -0.0248918 ],
        [-0.02532675, -0.01772211,  0.05514754, ...,  0.03839124,
         -0.05036234, -0.03766601],
        [ 0.04610632,  0.01409597,  0.03790993, ..., -0.02038677,
         -0.03649681,  0.04742099]], dtype=float32),
 array([ -1.29682487e-02,   1.38744503e-01,  -3.10007334e-01,
         -3.04996595e-02,  -1.39434069e-01,  -4.05185074e-02,
         -2.09797233e-01,  -4.62490469e-01,  -6.72216356e-01,
         -1.83647368e-02,  -2.93090612e-01,   5.11649624e-02,
         -2.74094105e-01,  -9.03906003e-02,  -7.21242726e-01,
         -2.51375604e-02,  -1.40052319e-01,  -1.31754786e-01,
         -1.88921779e-01,  -3.18406552e-01,  -3.45931239e-02,
         -1.89292878e-01,   3.80539931e-02,   3.54425013e-02,
         -6.34538352e-01,  -2.27093436e-02,  -5.49978614e-01,
         -2.85222325e-02,  -4.87636119e-01,  -2.94719964e-01,
         -4.62469608e-01,  -4.31859016e-01,  -4.95594800e-01,
         -7.55963206e-01,  -7.07836151e-01,   5.50588481e-02,
          1.01570776e-02,  -3.62383217e-01,  -2.37895608e-01,
         -3.48139226e-01,  -5.14193960e-02,  -4.49353665e-01,
         -2.04702299e-02,  -1.28980473e-01,  -6.01515993e-02,
         -4.11046803e-01,  -2.73511171e-01,  -4.22501177e-01,
          6.57678917e-02,  -3.77899945e-01,  -3.68858546e-01,
         -3.45079124e-01,  -1.21501423e-01,  -2.59954304e-01,
         -2.77339309e-01,   7.24700987e-02,  -1.75704360e-01,
         -1.79602101e-01,  -3.49472016e-01,  -4.22441006e-01,
         -3.98772031e-01,   4.78056073e-02,   1.63912345e-02,
         -1.73481293e-02,   2.03711018e-01,  -1.66458517e-01,
         -2.50248574e-02,  -4.33256328e-01,  -1.77355483e-02,
         -6.68845698e-02,  -6.33655787e-02,  -2.07219645e-01,
         -2.81381667e-01,  -2.10354477e-01,   9.65033993e-02,
          1.45252123e-01,  -1.62108362e-01,  -4.10078391e-02,
         -5.01093924e-01,   6.61657602e-02,  -3.54006797e-01,
         -2.72664815e-01,  -4.63590562e-01,  -2.76888013e-01,
          5.67168836e-03,  -1.63264722e-02,  -5.64372167e-02,
         -3.27719487e-02,  -1.25738844e-01,  -3.16582769e-02,
         -3.16652000e-01,   2.20678657e-01,  -4.90398854e-01,
         -3.87180448e-01,   4.62217331e-02,  -3.87124509e-01,
          3.44271868e-01,  -6.47646427e-01,  -4.47504744e-02,
         -3.12687427e-01,  -3.64519686e-01,  -1.19691178e-01,
         -1.22579239e-01,  -1.74031451e-01,  -3.50467891e-01,
         -3.85930926e-01,  -1.01258140e-02,   1.65355578e-01,
          2.38174275e-02,  -3.86843532e-01,  -2.11541757e-01,
         -1.60455573e-02,  -3.41660500e-01,  -2.41097137e-01,
         -3.58184397e-01,  -3.74646991e-01,  -5.68306029e-01,
          6.03663735e-02,  -2.25287676e-01,  -3.33954960e-01,
         -3.21863830e-01,  -5.74063025e-02,  -9.54797715e-02,
         -1.69863552e-01,   5.25663458e-02,  -1.78944767e-01,
         -4.96068239e-01,  -9.37457308e-02,  -4.91037033e-02,
         -5.45800686e-01,  -4.19147074e-01,  -3.63402218e-01,
         -9.55256671e-02,  -6.56951070e-02,  -4.74279895e-02,
          3.94136347e-02,  -6.89108312e-01,  -6.40569270e-01,
         -2.92730868e-01,  -4.21674043e-01,  -9.05798003e-02,
         -9.85799953e-02,  -3.34262311e-01,  -2.91352630e-01,
         -1.20481804e-01,  -1.30824670e-01,  -3.15101117e-01,
         -3.82897407e-01,  -3.67818296e-01,  -2.51174152e-01,
         -4.45220284e-02,  -3.63316804e-01,  -5.95236719e-01,
         -3.27549487e-01,  -5.18906057e-01,  -1.80942759e-01,
         -1.93147764e-01,  -1.63675278e-01,   5.25709763e-02,
         -1.69222236e-01,  -1.66612849e-01,  -1.89764783e-01,
          9.59388837e-02,  -1.79865390e-01,  -2.87416220e-01,
         -1.37040511e-01,  -3.68917108e-01,  -1.97503880e-01,
         -4.80307907e-01,  -9.74704884e-03,  -1.62035048e-01,
         -4.33685966e-02,  -3.75206321e-01,  -2.71574229e-01,
         -2.51338482e-01,  -1.91602707e-01,  -4.66123730e-01,
         -3.09535444e-01,  -3.18885483e-02,  -3.23637798e-02,
         -3.71796012e-01,  -2.26407617e-01,  -4.69909385e-02,
         -3.70391518e-01,  -5.37406743e-01,  -5.00004053e-01,
         -4.49130647e-02,   1.55784473e-01,  -3.39550585e-01,
         -5.15295863e-01,  -5.79936266e-01,   4.80024889e-03,
         -1.23718642e-01,  -6.55675307e-02,  -2.74233013e-01,
         -2.67147571e-01,  -4.20176655e-01,  -2.30046362e-02,
         -2.80579627e-01,  -6.52074635e-01,  -2.07271874e-01,
         -3.34823787e-01,  -5.11079669e-01,  -4.89039391e-01,
         -1.69896662e-01,  -6.09769404e-01,   1.67333558e-01,
         -1.52619872e-02,  -1.82103708e-01,  -1.59035064e-02,
         -2.82586038e-01,  -4.48576622e-02,  -2.77401984e-01,
         -1.18868940e-01,  -3.09958905e-01,  -4.54939663e-01,
         -6.84868218e-03,  -1.78479820e-01,  -4.12694991e-01,
         -4.86943096e-01,  -4.83419180e-01,  -2.92061418e-01,
         -3.56696308e-01,  -2.38492072e-01,  -1.99521467e-01,
         -6.62643433e-01,  -6.58789635e-01,  -3.13386142e-01,
         -2.39210613e-02,   3.81695509e-01,   3.89514342e-02,
         -4.21914130e-01,  -1.78643346e-01,  -3.58139843e-01,
         -2.31155585e-02,  -5.25866091e-01,  -2.01350115e-02,
          1.34515122e-01,  -4.72941786e-01,   1.28511051e-02,
         -1.92628369e-01,  -2.94919074e-01,  -1.21810228e-01,
         -2.63900816e-01,  -1.77175865e-01,  -3.85966711e-02,
         -3.91167760e-01,  -3.54940116e-01,  -4.08377945e-02,
         -2.46946454e-01,  -1.70614153e-01,   9.64559093e-02,
         -1.58487067e-01,  -1.40857771e-01,  -2.60191988e-02,
         -2.16996279e-02,  -2.01046526e-01,   1.07773796e-01,
         -7.25519285e-02,  -4.59324010e-02,  -3.97602469e-01,
         -2.86683738e-01,  -2.06594560e-02,  -2.32254282e-01,
         -1.47455707e-01,  -2.11738929e-01,  -3.97648931e-01,
         -1.92232862e-01,  -4.22664315e-01,  -2.10082695e-01,
         -3.69767874e-01,  -3.35989922e-01,  -2.50372291e-02,
         -2.56772131e-01,  -7.55918026e-01,  -1.45749766e-02,
         -5.94904542e-01,  -1.83992922e-01,  -1.98239967e-01,
          2.28624657e-01,  -3.67346585e-01,  -2.17467710e-01,
         -8.19451883e-02,  -5.01424968e-02,  -3.00576668e-02,
          2.42029456e-03,  -6.11475348e-01,  -2.48637870e-01,
         -1.25368005e-02,  -1.07831452e-02,   3.56794626e-01,
         -2.73973256e-01,  -5.00894673e-02,  -3.93987626e-01,
         -6.70151055e-01,   5.03201634e-02,  -3.47819924e-01,
          2.21592330e-04,  -9.35477093e-02,  -4.01370734e-01,
         -5.17268419e-01,  -2.08003540e-02,  -1.58300679e-02,
          1.09454863e-01,   4.86627640e-03,  -4.40006703e-01,
          1.10145152e-01,  -3.08435559e-01,  -2.27646939e-02,
         -6.15591705e-02,  -6.83150813e-02,   1.51192188e-01,
         -2.93954074e-01,   1.76271528e-01,  -5.47897398e-01,
         -2.94454783e-01,  -4.87583935e-01,  -2.25682836e-02,
         -2.61891991e-01,  -2.05876276e-01,  -2.91871820e-02,
         -4.65158612e-01,  -1.10427953e-01,   2.59957045e-01,
         -6.44603491e-01,  -5.89241982e-01,  -2.40099952e-01,
         -2.48620026e-02,   2.60877088e-02,  -3.69062722e-01,
         -5.85998118e-01,   6.35902397e-04,   1.52950898e-01,
         -1.31705374e-01,  -6.95600629e-01,  -6.93177283e-02,
         -3.34524751e-01,  -2.05166377e-02,  -4.04433101e-01,
         -3.34488690e-01,   4.12484966e-02,  -1.07743412e-01,
         -2.31767640e-01,  -5.87181449e-01,  -1.24916852e-01,
         -2.45317779e-02,  -4.82061923e-01,   4.29915352e-04,
         -2.29062542e-01,  -1.53157920e-01,  -8.75511765e-02,
         -1.93034634e-01,  -2.39149824e-01,  -2.81021118e-01,
         -1.92091212e-01,   4.84096706e-02,  -3.15482467e-01,
         -9.38970945e-04,  -7.32823536e-02,   1.46180347e-01,
         -7.48398662e-01,  -2.95927972e-01,  -1.01935327e-01,
         -2.25223079e-02,  -3.76603395e-01,  -3.72446418e-01,
         -5.44973463e-02,  -3.04856654e-02,  -8.12882781e-01,
         -6.35300994e-01,   1.01717256e-01,   1.15769980e-02,
          1.94745436e-01,  -4.62203443e-01,  -1.94413647e-01,
         -1.19787067e-01,   5.01835823e-01,  -1.22532628e-01,
         -4.83275265e-01,  -5.72950900e-01,  -1.68230399e-01,
         -2.53478941e-02,  -8.93718377e-02,  -2.09907755e-01,
          1.15736432e-01,   7.35889524e-02,  -2.25963101e-01,
         -1.25411734e-01,  -1.58686683e-01,   3.05348307e-01,
         -4.07805927e-02,  -6.87129676e-01,  -1.78614125e-01,
         -6.12517297e-02,  -1.26590893e-01,  -5.44444025e-01,
         -2.87909880e-02,  -1.61622658e-01,  -6.28022432e-01,
         -3.93144011e-01,  -4.14166540e-01,  -3.36472809e-01,
         -2.14290902e-01,  -1.57012552e-01,  -6.99233487e-02,
         -1.79140717e-01,  -3.44865173e-01,  -4.32067961e-01,
         -4.17658724e-02,  -1.92612112e-01,  -4.07513529e-01,
         -2.00688168e-01,  -3.12940218e-02,  -5.83245270e-02,
         -3.02525491e-01,  -6.36755228e-01,  -2.01398991e-02,
         -1.94140598e-01,  -5.85560381e-01,  -2.78204322e-01,
         -4.92228866e-01,   2.85394281e-01,  -5.29185772e-01,
         -5.80944479e-01,  -4.82267290e-01,  -3.02456468e-01,
         -2.17350312e-02,  -2.27617443e-01,  -8.41379631e-03,
         -5.19459188e-01,  -1.92483932e-01,  -6.69973344e-02,
         -3.18294495e-01,  -4.43626344e-01,   1.03083804e-01,
         -1.43494621e-01,  -3.98965865e-01,  -2.91880131e-01,
         -1.15407094e-01,  -2.33865350e-01,  -3.48333865e-01,
         -3.13846886e-01,  -2.00329088e-02,  -2.08419889e-01,
         -6.56257868e-02,  -3.15933287e-01,  -2.66032100e-01,
         -2.17209011e-01,  -2.57886738e-01,  -3.74219060e-01,
         -3.42252910e-01,  -3.02372843e-01,  -2.70351022e-01,
         -4.19028729e-01,  -2.16944158e-01,   1.65465083e-02,
         -1.38239786e-01,   8.82068649e-03,  -5.47306299e-01,
         -6.58184737e-02,  -1.07372276e-01,  -1.99595578e-02,
         -3.04633468e-01,  -2.42436364e-01,  -9.85036939e-02,
          8.13045427e-02,  -6.01692021e-01,  -7.83374131e-01,
         -3.54873002e-01,  -1.54401422e-01,  -1.99920405e-02,
         -6.02073036e-03,  -7.46182263e-01,  -5.17743170e-01,
         -1.43411651e-01,   1.35698587e-01,  -4.32992607e-01,
         -3.22256982e-01,   2.01625749e-01,  -1.68692529e-01,
          9.03868079e-02,  -7.36883581e-02,  -2.26779003e-02,
          7.53887817e-02,  -3.51618379e-01,  -6.96502507e-01,
         -1.97232455e-01,  -2.19720408e-01,  -1.76197141e-01,
         -3.31067145e-01,   2.52920628e-01,  -5.32557011e-01,
         -9.84433852e-03,  -2.28284430e-02,  -2.18466327e-01,
         -2.50813589e-02,  -1.22822799e-01,  -6.21357895e-02,
         -1.85140949e-02,   1.55188337e-01,  -2.91802138e-01,
         -1.76329892e-02,  -3.60844210e-02,  -5.81378281e-01,
         -6.11039221e-01,  -3.28095675e-01,  -2.83731908e-01,
         -1.66193381e-01,   5.52292354e-02,   6.29878119e-02,
         -3.41305107e-01,  -1.39835373e-01,   1.71938047e-01,
         -1.84613727e-02,   7.50863180e-02,  -3.44148017e-02,
         -3.53854299e-01,  -5.12476027e-01,   1.22042328e-01,
         -5.39535470e-02,   3.05281021e-03,  -1.19409911e-01,
         -2.89323032e-01,  -6.71940520e-02,  -2.19452642e-02,
         -2.90004104e-01,  -1.76387712e-01,  -4.56134796e-01,
         -8.09880495e-01,  -1.83778346e-01,  -2.31890544e-01,
         -4.52327728e-01,  -2.06816241e-01,  -1.38748497e-01,
         -4.18441355e-01,  -5.38856745e-01,  -5.05130768e-01,
         -1.75971299e-01,  -1.19080685e-01,  -9.46213081e-02,
         -3.64823714e-02,  -3.22997957e-01,  -1.34447142e-01,
         -1.27073288e-01,   1.64654911e-01,  -9.78678912e-02,
         -4.47389364e-01,  -2.54144296e-02,   1.73969138e-02,
         -2.04480872e-01,  -4.30503398e-01,  -1.67036086e-01,
         -2.49711365e-01,  -3.37412119e-01,  -6.02359474e-01,
         -6.62094355e-01,  -1.16948448e-01,   9.77696292e-03,
         -5.21902740e-01,  -2.33485606e-02,  -6.64649755e-02,
         -6.00027978e-01,  -5.42070754e-02,  -2.38561943e-01,
         -4.47000265e-01,   1.17274612e-01,  -1.11540303e-01,
         -1.02203742e-01,  -6.74192980e-02,  -1.72974497e-01,
         -2.43933983e-02,  -2.18470603e-01,  -1.02555685e-01,
         -5.01730680e-01,  -1.63745075e-01,  -2.48166338e-01,
          4.25796956e-02,  -8.81046131e-02,  -4.94634926e-01,
         -2.48743445e-01,   8.22583865e-03,  -2.14855313e-01,
         -5.94667614e-01,   1.23224966e-01,  -2.28983104e-01,
         -4.89580818e-02,  -3.53976309e-01,  -1.02518976e-01,
         -2.80924350e-01,   2.18932718e-01,  -9.42684943e-04,
         -2.78814733e-01,  -2.43697301e-01,  -4.07780051e-01,
         -1.57622676e-02,  -4.32732075e-01,   2.76384447e-02,
         -2.56971091e-01,  -1.39276221e-01,  -2.89412320e-01,
         -7.84103293e-03,  -5.75612962e-01,  -2.65779234e-02,
         -2.83633530e-01,  -2.42152084e-02,  -3.54716778e-01,
         -5.25303543e-01,  -6.30853772e-02,  -2.22892091e-01,
         -3.32897723e-01,  -8.58137235e-02,  -1.35768950e-01,
         -4.00102228e-01,  -6.81776628e-02,  -1.11637965e-01,
          8.71941745e-02,   7.97185600e-02,  -4.74733919e-01,
         -5.36120776e-03,  -2.00053956e-02,   2.74125468e-02,
         -5.23373425e-01,  -3.52810740e-01,  -5.75067937e-01,
         -1.27765425e-02,  -2.41196215e-01,   1.35370884e-02,
         -3.42776716e-01,  -2.61937886e-01,  -1.73471346e-01,
         -7.74265826e-01,  -3.25414896e-01,  -6.52070194e-02,
         -1.75177939e-02,  -2.78512776e-01,  -1.26804650e-01,
         -1.54330492e-01,  -2.43354395e-01,  -5.10048628e-01,
         -5.22104055e-02,  -4.48061913e-01,  -2.54915148e-01,
         -3.71145964e-01,  -2.34785691e-01,  -5.76828778e-01,
         -5.20584345e-01,  -2.01370478e-01,  -3.43574703e-01,
         -3.95394504e-01,  -7.02085435e-01,   3.80159239e-03,
         -5.05006194e-01,  -6.66690245e-02,  -2.13820174e-01,
         -1.86356172e-01,  -1.98591515e-01,  -2.26664558e-01,
         -9.84562710e-02,   9.10461769e-02,  -1.63858235e-01,
         -6.71461642e-01,  -2.07045935e-02,  -1.84064224e-01,
         -1.52253630e-02,  -6.44623414e-02,  -1.90693051e-01,
         -3.26317549e-01,  -3.90465967e-02,  -4.31612767e-02,
         -2.69320831e-02,  -2.61054486e-01,  -5.56032240e-01,
         -1.39396250e-01,  -3.04626554e-01,  -4.00418974e-02,
         -5.22964954e-01,  -2.74515212e-01,  -2.05182180e-01,
         -4.55017984e-01,  -4.10655349e-01,  -3.91681463e-01,
         -2.95707285e-01,  -1.75162852e-02,  -1.80232033e-01,
         -9.38054398e-02,  -4.48614866e-01,  -1.20916396e-01,
         -1.26026660e-01,  -6.13098264e-01,  -9.16779786e-02,
         -1.24931745e-01,  -1.14639051e-01,  -5.89349389e-01,
         -2.86892831e-01,  -4.32475626e-01,  -4.53839451e-01,
         -5.40873766e-01,  -3.22011739e-01,  -1.04171380e-01,
         -2.03116417e-01,  -7.34383706e-03,  -2.95767933e-01,
          3.77100818e-02,  -3.95163864e-01,  -9.11748350e-01,
         -2.14269429e-01,  -4.47106093e-01,  -1.02919694e-02,
         -1.46425188e-01,   1.30215868e-01,   3.46448004e-01,
         -7.53604919e-02,  -3.68188143e-01,  -1.75004661e-01,
         -3.42096955e-01,  -1.19322361e-02,   9.38493479e-03,
         -5.18787801e-01,  -1.09108455e-01,   6.15557991e-02,
         -8.33496079e-03,  -6.41730651e-02,  -1.36719868e-02,
         -3.73748362e-01,  -3.73859495e-01,   2.80248914e-02,
         -3.09117913e-01,  -2.88713902e-01,  -4.28494245e-01,
         -5.13740003e-01,  -1.57594740e-01,  -4.70732421e-01,
         -1.38654308e-02,  -6.85215056e-01,  -3.66586596e-01,
         -1.41351402e-01,  -1.13854766e-01,  -5.36643863e-01,
         -4.75565642e-01,  -5.00832915e-01,  -4.08477843e-01,
         -3.66504490e-01,  -1.15367234e-01,  -2.48915218e-02,
         -4.96757418e-01,   1.17366053e-01,  -2.26039514e-01,
         -5.49678802e-01,  -2.75789142e-01,  -5.08426309e-01,
          1.07284091e-01,  -2.54364550e-01,  -3.72139484e-01,
         -3.34391892e-01,   2.10764147e-02,  -1.33560911e-01,
         -9.50245783e-02,  -3.13357562e-01,  -2.62188077e-01,
         -5.32095313e-01,  -5.31459413e-03,  -3.21489833e-02,
         -7.84164011e-01,  -1.10715240e-01,  -2.87352562e-01,
         -5.71807444e-01,  -2.04134420e-01,   7.85130933e-02,
         -3.69185776e-01,  -1.98006928e-02,   6.63151639e-03,
         -2.87224799e-01,   5.36596589e-02,  -7.96930939e-02,
         -2.82612413e-01,  -1.87133670e-01,  -6.54792845e-01,
         -8.59472081e-02,  -1.13062121e-01,  -1.83315545e-01,
         -2.58277714e-01,  -5.51701725e-01,  -5.59242129e-01,
         -1.50169775e-01,   4.73141856e-02,  -1.68764800e-01,
         -2.75284111e-01,  -4.43699747e-01,  -2.76820183e-01,
         -3.51191200e-02,  -1.07176892e-01,  -4.73967902e-02,
         -4.53751475e-01,  -2.84370124e-01,  -4.89342690e-01,
         -3.81000303e-02,  -5.29655755e-01,  -1.50656566e-01,
         -4.64593619e-01,  -1.58045471e-01,  -7.06188157e-02,
         -4.04648870e-01,  -3.15317452e-01,  -2.87708908e-01,
         -1.71832666e-01,  -2.27938369e-01,  -2.11054739e-02,
         -3.29687774e-01,  -1.82581544e-01,  -2.17228252e-02,
          2.08218992e-02,  -1.46109968e-01,  -7.96382129e-02,
         -3.17795098e-01,  -5.75634658e-01,  -3.44916396e-02,
         -4.36014533e-01,  -2.85244137e-02,  -5.68732560e-01,
         -5.59068859e-01,  -1.22407533e-01,  -2.56792486e-01,
         -2.97368616e-01,  -3.03129584e-01,  -1.62084669e-01,
         -2.64727145e-01,  -4.05563980e-01,   3.00995618e-01,
         -1.86940640e-01,  -9.05097499e-02,  -1.19438395e-01,
         -1.88409179e-01,  -3.68620992e-01,   3.19603570e-02,
         -5.20787895e-01,  -2.95364499e-01,  -1.96136490e-01,
          1.30156171e+00,  -3.09764799e-02,  -1.63758829e-01,
         -1.63395420e-01,  -1.06308326e-01,  -3.37606370e-01,
         -4.02779371e-01,  -1.04163669e-01,  -3.29879135e-01,
         -6.24738149e-02,   7.57394284e-02,  -6.51596487e-01,
         -2.37611696e-01,  -5.25772333e-01,   1.44061729e-01,
         -2.59940475e-01,  -2.72920489e-01,  -3.10522407e-01,
         -8.48866284e-01,  -5.29746771e-01,  -1.75354518e-02,
         -8.73476788e-02,  -4.62230533e-01,  -3.12623024e-01,
         -4.66565102e-01,  -2.35941991e-01,  -4.72842991e-01,
         -8.59152302e-02,  -3.31128508e-01,  -1.34016275e-01,
         -6.82140663e-02,  -1.31053597e-01,   3.27668451e-02,
         -4.59252357e-01,  -7.40645081e-02,  -2.32884094e-01,
         -2.48913141e-03,  -5.38118541e-01,  -6.48121983e-02,
         -2.82097995e-01,  -4.83397216e-01,  -3.75957131e-01,
         -1.20243065e-01,  -2.91992631e-02,  -2.34807402e-01,
         -8.57004896e-02,  -1.76332936e-01,  -4.79596853e-01,
         -3.59954983e-01,  -3.86393666e-01,  -1.49604112e-01,
          9.89474952e-02,  -1.43513409e-02,  -5.00253379e-01,
         -2.31766224e-01,  -2.78296471e-01,  -1.47517323e-01,
         -2.70760179e-01,   5.62180728e-02,   1.26814142e-01,
         -2.58570649e-02,  -3.02321255e-01,  -5.06240189e-01,
         -3.60810488e-01,  -1.61365643e-01,  -1.28059566e-01,
         -2.62734950e-01,  -1.67697724e-02,   9.22571719e-02,
         -7.30941415e-01,  -3.17986846e-01,  -3.49215209e-01,
         -4.75899428e-01,  -5.54573357e-01,  -2.22814456e-01,
         -9.33618564e-03,  -4.88777943e-02,  -2.79946309e-02,
         -2.43498668e-01,   1.63741887e-01,  -8.86490270e-02,
         -1.80582032e-02,   5.81286959e-02,  -5.06547272e-01,
         -2.36781448e-01,  -2.82066971e-01,   3.62231545e-02,
          5.59952706e-02,  -5.27004182e-01,  -5.63789010e-02,
         -6.33812070e-01,  -7.20118701e-01,  -3.27905029e-01,
         -1.09615184e-01,  -1.97968498e-01,  -3.48774903e-02,
         -4.36178327e-01,  -1.90760285e-01,  -2.00712010e-01,
         -4.05785292e-02,  -7.98018798e-02,  -6.48312092e-01,
         -5.16030610e-01,  -1.82418972e-02,  -3.22774321e-01,
         -1.91510841e-01,  -1.31354675e-01,  -5.67911983e-01,
         -4.27046567e-01,  -2.61492878e-01,  -7.63690919e-02,
         -3.53502780e-01,  -2.86672637e-02,   6.57036155e-02,
         -2.32697666e-01,  -2.25740999e-01,  -2.21521795e-01,
          3.64017077e-02,  -4.65820670e-01,  -1.67809874e-01,
         -2.34040041e-02,  -3.40095460e-01,   5.10562137e-02,
         -2.80955017e-01,   2.17410009e-02,  -2.25610495e-01,
         -2.61850543e-02,  -1.18860357e-01,   9.67218876e-02,
         -6.98161423e-01,  -4.03901875e-01,  -2.49750782e-02,
         -1.49894670e-01,  -1.55417640e-02,  -2.35045440e-02,
         -1.22158304e-02,  -3.60701740e-01,  -5.72664201e-01,
         -4.56410229e-01,  -9.86423045e-02,  -5.59065938e-01,
         -2.43323550e-01,   1.14932351e-01,  -1.32146357e-02,
         -1.13701306e-01,  -2.43878905e-02,   3.04878563e-01,
         -2.93137670e-01,  -4.26690668e-01,  -1.90759376e-01,
         -5.80423713e-01,   1.61198322e-02,  -3.25486124e-01,
         -3.21475148e-01,  -2.53617167e-01,  -1.20874017e-01,
         -4.76823658e-01,  -3.47528964e-01,  -2.89901286e-01,
          2.24457998e-02,  -4.97344643e-01,   1.08718812e+00,
         -2.79220223e-01], dtype=float32),
 array([[ 0.03900816,  0.00785677, -0.06511776, ...,  0.00776991,
         -0.05963232, -0.05985177],
        [-0.20750827,  0.08817152,  0.40323174, ...,  0.20854132,
         -0.11089708,  0.14705186],
        [-0.24851227,  0.36102909,  0.07329425, ...,  0.12305254,
          0.02824712,  0.2746895 ],
        ..., 
        [-0.27076459,  0.04397521,  0.10150083, ..., -0.02952144,
          0.35495111,  0.01788467],
        [-0.22880824, -0.14765862, -0.01148497, ..., -0.04802479,
         -0.11898327,  0.16021334],
        [-0.01458607,  0.51388001,  0.25630933, ...,  0.10885861,
         -0.15997633,  0.01113635]], dtype=float32),
 array([-0.36252829, -0.41307127, -0.37561458, -0.790694  , -0.7867986 ,
        -0.39656818, -0.49989551, -0.56961799, -0.67535901, -0.78190619,
        -0.64679927, -0.62336636, -0.73334086, -0.51707494, -0.80007225,
        -0.57039291, -0.43117863, -0.57423478, -1.01204598, -0.99576569,
        -0.45388478, -0.9715423 , -0.57562113, -0.85434681, -0.4783178 ,
        -0.65333492, -0.56394655, -0.51519966, -0.87941819, -0.9431147 ,
        -0.52889907, -0.51141596, -1.04037309, -0.87605566, -0.5586676 ,
        -0.67145008, -0.62178028, -0.74712718, -0.47700772, -0.81794   ,
        -0.94796181, -1.03332078, -0.99911004, -0.35762793, -0.41830212,
        -0.44990394, -0.54796964, -0.64622766, -0.36980084, -0.62949306,
        -0.73081511, -0.92071664, -0.96040893, -0.17141432, -0.50711352,
        -0.68742466, -0.58205402, -0.60873783, -0.51237881, -0.42307621,
        -0.59278268, -0.77905166, -0.70859444, -0.99470675, -0.68357819,
        -0.45728955, -0.98573047, -0.7740072 , -0.76561183, -0.38337517,
        -0.78785807, -0.9682638 , -0.41092423, -0.81709141, -0.4595961 ,
        -0.45476505, -0.89052409, -0.95178139, -0.920165  , -0.83498871,
        -0.54309958, -0.62142682, -0.10648966, -0.55824465, -0.51698029,
        -0.65391433, -0.73073816, -0.63968295, -0.73563075, -0.37823838,
        -0.83874625, -0.35336301, -0.72945499, -0.61786187, -1.04557991,
        -0.58565521, -0.35223064, -0.30662736, -0.66361117, -0.74605358,
        -0.79575521, -1.12011874, -0.65195775, -0.66316205, -0.30292839,
        -0.97478765, -0.30300212, -0.98781288, -0.88087404, -0.56088251,
        -0.82704026, -0.57432526, -0.44808209, -0.65736598, -0.7800023 ,
        -0.43863136, -0.71997589, -0.79668957, -0.58597511, -0.79392022,
        -0.91689253, -0.17079359, -0.70273119, -0.31935337, -0.99297088,
        -1.21429086, -0.54536754, -0.66847122, -1.0803057 , -0.02116329,
        -0.36946481, -0.78094089, -0.67028719, -0.63478422, -0.56762469,
        -0.59048861, -0.40834036, -0.76510531, -0.86944491, -0.26183733,
        -0.64363545, -0.21043499, -0.80520427, -0.98543239, -1.02239132,
        -0.87130302, -1.06532812, -0.47601402, -0.55352145, -0.75008106,
        -0.57477021, -0.73686802, -0.44472244, -0.64302158, -0.61648601,
        -1.09791934, -0.83204991, -0.40939972, -0.82405424, -0.57132626,
        -0.85813493, -0.84275389, -0.53043413, -1.03980398, -0.41696942,
        -0.99465734, -0.70751721, -0.94126099, -0.70646006, -0.85644752,
        -0.75323451, -0.62099051, -0.99225199, -0.81427616, -0.72105873,
        -0.3865678 , -0.71929121, -0.85359961, -0.47467613, -0.49992275,
        -0.78395241, -0.66783226, -0.85084015, -0.37230313, -0.74241304,
        -0.52368313, -0.57518154, -0.88761586, -0.78079957, -0.84552658,
        -0.60064358, -0.58771318, -0.68866116, -0.7030834 , -0.8059988 ,
        -0.71570534, -0.56441271, -0.89694452, -0.83912975, -0.46641162], dtype=float32),
 array([[-0.78751951,  0.02826324, -0.07172652, ..., -0.27620244,
         -0.47863257, -0.49731782],
        [-0.49682441,  0.04474993, -0.77598727, ..., -0.54524791,
         -0.21792939, -0.47720003],
        [-0.2323969 , -0.88028777, -0.2349651 , ..., -0.14491257,
         -0.17279406, -0.64144588],
        ..., 
        [-0.7111882 , -0.30641097, -0.66904122, ..., -0.0798426 ,
         -0.57756215, -0.08725328],
        [ 0.11830693,  0.07352046,  0.08562858, ...,  0.09446803,
         -0.41451645, -0.35526502],
        [-0.92134595,  0.0993112 , -0.0636774 , ..., -0.0216356 ,
         -0.54615569, -0.05519475]], dtype=float32),
 array([-0.28950188, -0.33981469, -0.49054769, -0.24692491, -0.54108179,
        -0.53850734, -0.51629019, -0.45034203,  0.94987106,  0.34385717], dtype=float32)]

In [20]:
print("Training time: " + str(trainer.get_training_time()))
print("Accuracy: " + str(evaluate_accuracy(trained_model, test_set)))


Training time: 22619.2383449
Accuracy: 0.9859
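The reported training time carries no unit. Assuming it is wall-clock seconds, a little arithmetic puts it in perspective for one epoch over the 6,060,000 training instances on 30 workers:

```python
training_time_s = 22619.2383449   # value printed by trainer.get_training_time()
num_instances = 6060000           # training-set size reported earlier
num_workers = 30

hours = training_time_s / 3600.0
throughput = num_instances / training_time_s   # samples per second, all workers
per_worker = throughput / num_workers          # samples per second, one worker

print("Training time: %.2f hours" % hours)     # Training time: 6.28 hours
print("Throughput: %.1f samples/s" % throughput)
print("Per worker: %.2f samples/s" % per_worker)
```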