In [1]:
# Execute the shared WikiQA setup notebook in this kernel. Per its printed output it
# loads the raw data packs (`train_pack_raw`, `dev_pack_raw`, `test_pack_raw`),
# initializes `ranking_task`, and loads the GloVe vectors as `glove_embedding`.
%run ./tutorials/wikiqa/init.ipynb


Using TensorFlow backend.
matchzoo version 2.1.0

data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`

In [2]:
import numpy as np
import pandas as pd
from keras.optimizers import Adam
from keras.utils import to_categorical

import matchzoo as mz
from matchzoo.contrib.models.esim import ESIM

In [3]:
def load_filtered_data(preprocessor, data_type, seed=None):
    """Load one WikiQA split, preprocess it, filter unusable rows and shuffle.

    The split is loaded with ``mz.datasets.wiki_qa.load_data(..., task='ranking')``
    and unpacked into ``X`` (a dict of arrays) and ``Y`` (labels). Two filters are
    then applied, in this order:

    1. rows whose ``length_left`` or ``length_right`` is 0 are dropped;
    2. rows whose ``id_left`` never appears with a label equal to 1 are dropped.

    The surviving rows are shuffled, with every array in ``X`` and ``Y`` reordered
    by the same permutation.

    Args:
        preprocessor: object exposing ``fit_transform`` and ``transform``; it is
            fitted only when ``data_type == 'train'``.
        data_type: one of ``'train'``, ``'dev'`` or ``'test'``.
        seed: optional seed. When given, the shuffle uses a private
            ``np.random.RandomState(seed)`` so the order is repeatable. When
            ``None`` (default) the global NumPy RNG is used, as before.

    Returns:
        Tuple ``(X, Y, preprocessor)`` with the filtered, shuffled data and the
        (possibly newly fitted) preprocessor.

    Raises:
        AssertionError: if ``data_type`` is not one of the three split names.
    """
    assert data_type in ('train', 'dev', 'test'), \
        "data_type must be 'train', 'dev' or 'test', got %r" % (data_type,)
    data_pack = mz.datasets.wiki_qa.load_data(data_type, task='ranking')

    # Only the training split fits the preprocessor; other splits reuse its state.
    if data_type == 'train':
        processed = preprocessor.fit_transform(data_pack)
    else:
        processed = preprocessor.transform(data_pack)
    X, Y = processed.unpack()

    # Boolean masks (rather than a list of integer indices) keep the indexing valid
    # even when no row survives: an empty integer list becomes a float array, which
    # NumPy refuses to use as an index.
    non_empty = (np.asarray(X["length_left"]) != 0) & (np.asarray(X["length_right"]) != 0)
    print("Removed empty data. Found ", (Y.shape[0] - int(non_empty.sum())))

    for k in X.keys():
        X[k] = X[k][non_empty]
    Y = Y[non_empty]

    # Y is indexed as 2-D here; column 0 holds the relevance label.
    pos_idx = (Y == 1)[:, 0]
    pos_qid = X["id_left"][pos_idx]
    # Vectorized membership test instead of a per-row `qid in array` scan.
    keep_mask = np.isin(X["id_left"], pos_qid)
    keep_idx = np.flatnonzero(keep_mask)
    print("Removed questions with no pos label. Found ", (~keep_mask).sum())

    print("shuffling...")
    if seed is None:
        np.random.shuffle(keep_idx)
    else:
        np.random.RandomState(seed).shuffle(keep_idx)
    for k in X.keys():
        X[k] = X[k][keep_idx]
    Y = Y[keep_idx]

    return X, Y, preprocessor

In [4]:
# Sequence lengths handed to the preprocessor and to the model's `input_shapes`.
fixed_length_left = 10
fixed_length_right = 40
# Passed to `model.fit` in the training cell.
batch_size = 32
epochs = 5

# Seed NumPy's global RNG so the `np.random.shuffle` call in `load_filtered_data`
# produces the same ordering on every Restart & Run All. This alone does not pin
# down every source of randomness in the deep-learning backend.
random_seed = 42
np.random.seed(random_seed)

In [5]:
# Prepare data: the preprocessor is fitted on the training split and then reused,
# already fitted, for the dev split.
preprocessor = mz.preprocessors.BasicPreprocessor(
    fixed_length_left=fixed_length_left,
    fixed_length_right=fixed_length_right,
    remove_stop_words=False,
    filter_low_freq=10,
)

train_X, train_Y, preprocessor = load_filtered_data(preprocessor, 'train')
val_X, val_Y, _ = load_filtered_data(preprocessor, 'dev')

# The dev split doubles as the evaluation set for the metrics callback.
# Disabled alternative kept from the original:
# pred_X, pred_Y, _ = load_filtered_data(preprocessor, 'test')
# NOTE(review): the original remark here read "no prediction label for quora
# dataset", but this notebook loads WikiQA — confirm whether the test split can be used.
pred_X, pred_Y = val_X, val_Y

# One embedding row per entry of the fitted vocabulary's term -> index mapping.
# The zero-returning initializer is presumably used for terms GloVe lacks — confirm.
embedding_matrix = glove_embedding.build_matrix(
    preprocessor.context['vocab_unit'].state['term_index'],
    initializer=lambda: 0,
)


Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 10798.93it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 8019.65it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 1415354.12it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 226166.63it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 233892.08it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 782897.32it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 1175423.27it/s]
Building Vocabulary from a datapack.: 100%|██████████| 358408/358408 [00:00<00:00, 4845654.07it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 15108.05it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 8129.15it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 222548.25it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 324738.11it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 122413.67it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 821484.73it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 1319786.92it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 200871.36it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 180842.83it/s]
Removed empty data. Found  91
Removed questions with no pos label. Found  11642
shuffling...
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 296/296 [00:00<00:00, 15853.43it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2708/2708 [00:00<00:00, 8318.22it/s]
Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 232964.32it/s]
Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 200892.23it/s]
Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 231808.96it/s]
Processing length_left with len: 100%|██████████| 296/296 [00:00<00:00, 562279.88it/s]
Processing length_right with len: 100%|██████████| 2708/2708 [00:00<00:00, 1159470.73it/s]
Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 183357.55it/s]
Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 178815.40it/s]
Removed empty data. Found  8
Removed questions with no pos label. Found  1595
shuffling...

In [6]:
# Ranking variant: an ESIM model configured with matchzoo's default Ranking task.
model = ESIM()

# All hyper-parameters live in one mapping and are applied in the order listed.
ranking_params = {
    'task': mz.tasks.Ranking(),
    'mask_value': 0,
    'input_shapes': [[fixed_length_left, ], [fixed_length_right, ]],
    'lstm_dim': 300,
    'embedding_input_dim': preprocessor.context['vocab_size'],
    'embedding_output_dim': 300,
    'embedding_trainable': False,  # embedding weights stay frozen during training
    'dropout_rate': 0.5,
    'mlp_num_units': 300,
    'mlp_num_layers': 0,
    'mlp_num_fan_out': 300,
    'mlp_activation_func': 'tanh',
    'optimizer': Adam(lr=4e-4),
}
for _param_name, _param_value in ranking_params.items():
    model.params[_param_name] = _param_value

model.guess_and_fill_missing_params()
model.build()
model.compile()

# Prints the Keras layer table as text; no graph plot is drawn.
model.backend.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
text_left (InputLayer)          (None, 10)           0                                            
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 40)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           multiple             1930500     text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
dropout_1 (Dropout)             multiple             0           embedding[0][0]                  
                                                                 embedding[1][0]                  
                                                                 dense_1[0][0]                    
                                                                 dense_1[1][0]                    
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               multiple             0           text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) multiple             1442400     dropout_1[0][0]                  
                                                                 dropout_1[1][0]                  
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 10, 600)      0           bidirectional_1[0][0]            
                                                                 lambda_2[0][0]                   
__________________________________________________________________________________________________
multiply_2 (Multiply)           (None, 40, 600)      0           bidirectional_1[1][0]            
                                                                 lambda_3[0][0]                   
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1, 40)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 10, 40)       0           multiply_1[0][0]                 
                                                                 multiply_2[0][0]                 
__________________________________________________________________________________________________
multiply_3 (Multiply)           (None, 10, 40)       0           lambda_4[0][0]                   
                                                                 lambda_5[0][0]                   
__________________________________________________________________________________________________
permute_1 (Permute)             (None, 40, 10)       0           dot_1[0][0]                      
                                                                 multiply_3[0][0]                 
__________________________________________________________________________________________________
atten_mask (Lambda)             multiple             0           dot_1[0][0]                      
                                                                 multiply_3[0][0]                 
                                                                 permute_1[0][0]                  
                                                                 permute_1[1][0]                  
__________________________________________________________________________________________________
softmax_1 (Softmax)             multiple             0           atten_mask[0][0]                 
                                                                 atten_mask[1][0]                 
__________________________________________________________________________________________________
dot_2 (Dot)                     (None, 10, 600)      0           softmax_1[0][0]                  
                                                                 multiply_2[0][0]                 
__________________________________________________________________________________________________
dot_3 (Dot)                     (None, 40, 600)      0           softmax_1[1][0]                  
                                                                 multiply_1[0][0]                 
__________________________________________________________________________________________________
subtract_1 (Subtract)           (None, 10, 600)      0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
__________________________________________________________________________________________________
multiply_4 (Multiply)           (None, 10, 600)      0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
__________________________________________________________________________________________________
subtract_2 (Subtract)           (None, 40, 600)      0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
__________________________________________________________________________________________________
multiply_5 (Multiply)           (None, 40, 600)      0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 10, 2400)     0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
                                                                 subtract_1[0][0]                 
                                                                 multiply_4[0][0]                 
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 2400)     0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
                                                                 subtract_2[0][0]                 
                                                                 multiply_5[0][0]                 
__________________________________________________________________________________________________
dense_1 (Dense)                 multiple             720300      concatenate_1[0][0]              
                                                                 concatenate_2[0][0]              
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) multiple             1442400     dropout_1[2][0]                  
                                                                 dropout_1[3][0]                  
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_10 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
lambda_12 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
multiply_6 (Multiply)           (None, 10, 600)      0           bidirectional_2[0][0]            
                                                                 lambda_6[0][0]                   
__________________________________________________________________________________________________
multiply_7 (Multiply)           (None, 10, 600)      0           bidirectional_2[0][0]            
                                                                 lambda_8[0][0]                   
__________________________________________________________________________________________________
multiply_8 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            
                                                                 lambda_10[0][0]                  
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            
                                                                 lambda_12[0][0]                  
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 600)          0           multiply_6[0][0]                 
                                                                 lambda_6[0][0]                   
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 600)          0           multiply_7[0][0]                 
__________________________________________________________________________________________________
lambda_11 (Lambda)              (None, 600)          0           multiply_8[0][0]                 
                                                                 lambda_10[0][0]                  
__________________________________________________________________________________________________
lambda_13 (Lambda)              (None, 600)          0           multiply_9[0][0]                 
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 1200)         0           lambda_7[0][0]                   
                                                                 lambda_9[0][0]                   
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 1200)         0           lambda_11[0][0]                  
                                                                 lambda_13[0][0]                  
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 2400)         0           concatenate_3[0][0]              
                                                                 concatenate_4[0][0]              
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 300)          720300      concatenate_5[0][0]              
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            301         dropout_1[4][0]                  
==================================================================================================
Total params: 6,256,201
Trainable params: 4,325,701
Non-trainable params: 1,930,500
__________________________________________________________________________________________________

In [7]:
# Train the model currently bound to `model` (the ranking ESIM built in the previous
# cell). After every epoch the callback scores the dev pairs with the task's metrics.
model.load_embedding_matrix(embedding_matrix)
evaluate = mz.callbacks.EvaluateAllMetrics(
    model,
    x=pred_X,
    y=pred_Y,
    once_every=1,
    batch_size=len(pred_Y),  # the whole evaluation set in a single batch
)

history = model.fit(
    x=[train_X['text_left'], train_X['text_right']],
    y=train_Y,
    validation_data=(val_X, val_Y),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[evaluate],
)


Train on 8627 samples, validate on 1130 samples
Epoch 1/5
8627/8627 [==============================] - 48s 6ms/step - loss: 0.1073 - val_loss: 0.0984
Validation: mean_average_precision(0.0): 0.6222655981584554
Epoch 2/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.0994 - val_loss: 0.0974
Validation: mean_average_precision(0.0): 0.640342571890191
Epoch 3/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.0944 - val_loss: 0.0981
Validation: mean_average_precision(0.0): 0.633281742507933
Epoch 4/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.0915 - val_loss: 0.0898
Validation: mean_average_precision(0.0): 0.6479046351993808
Epoch 5/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.0893 - val_loss: 0.0931
Validation: mean_average_precision(0.0): 0.6506805763854636

In [6]:
# Classification variant: the same ESIM configuration, but with a two-class
# Classification task that reports accuracy.
classification_task = mz.tasks.Classification(num_classes=2)
classification_task.metrics = 'acc'

model = ESIM()

# All hyper-parameters live in one mapping and are applied in the order listed.
classification_params = {
    'task': classification_task,
    'mask_value': 0,
    'input_shapes': [[fixed_length_left, ], [fixed_length_right, ]],
    'lstm_dim': 300,
    'embedding_input_dim': preprocessor.context['vocab_size'],
    'embedding_output_dim': 300,
    'embedding_trainable': False,  # embedding weights stay frozen during training
    'dropout_rate': 0.5,
    'mlp_num_units': 300,
    'mlp_num_layers': 0,
    'mlp_num_fan_out': 300,
    'mlp_activation_func': 'tanh',
    'optimizer': Adam(lr=4e-4),
}
for _param_name, _param_value in classification_params.items():
    model.params[_param_name] = _param_value

model.guess_and_fill_missing_params()
model.build()
model.compile()

# Prints the Keras layer table as text; no graph plot is drawn.
model.backend.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
text_left (InputLayer)          (None, 10)           0                                            
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 40)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           multiple             1930500     text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
dropout_1 (Dropout)             multiple             0           embedding[0][0]                  
                                                                 embedding[1][0]                  
                                                                 dense_1[0][0]                    
                                                                 dense_1[1][0]                    
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               multiple             0           text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) multiple             1442400     dropout_1[0][0]                  
                                                                 dropout_1[1][0]                  
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 10, 600)      0           bidirectional_1[0][0]            
                                                                 lambda_2[0][0]                   
__________________________________________________________________________________________________
multiply_2 (Multiply)           (None, 40, 600)      0           bidirectional_1[1][0]            
                                                                 lambda_3[0][0]                   
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1, 40)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 10, 40)       0           multiply_1[0][0]                 
                                                                 multiply_2[0][0]                 
__________________________________________________________________________________________________
multiply_3 (Multiply)           (None, 10, 40)       0           lambda_4[0][0]                   
                                                                 lambda_5[0][0]                   
__________________________________________________________________________________________________
permute_1 (Permute)             (None, 40, 10)       0           dot_1[0][0]                      
                                                                 multiply_3[0][0]                 
__________________________________________________________________________________________________
atten_mask (Lambda)             multiple             0           dot_1[0][0]                      
                                                                 multiply_3[0][0]                 
                                                                 permute_1[0][0]                  
                                                                 permute_1[1][0]                  
__________________________________________________________________________________________________
softmax_1 (Softmax)             multiple             0           atten_mask[0][0]                 
                                                                 atten_mask[1][0]                 
__________________________________________________________________________________________________
dot_2 (Dot)                     (None, 10, 600)      0           softmax_1[0][0]                  
                                                                 multiply_2[0][0]                 
__________________________________________________________________________________________________
dot_3 (Dot)                     (None, 40, 600)      0           softmax_1[1][0]                  
                                                                 multiply_1[0][0]                 
__________________________________________________________________________________________________
subtract_1 (Subtract)           (None, 10, 600)      0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
__________________________________________________________________________________________________
multiply_4 (Multiply)           (None, 10, 600)      0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
__________________________________________________________________________________________________
subtract_2 (Subtract)           (None, 40, 600)      0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
__________________________________________________________________________________________________
multiply_5 (Multiply)           (None, 40, 600)      0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 10, 2400)     0           multiply_1[0][0]                 
                                                                 dot_2[0][0]                      
                                                                 subtract_1[0][0]                 
                                                                 multiply_4[0][0]                 
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 2400)     0           multiply_2[0][0]                 
                                                                 dot_3[0][0]                      
                                                                 subtract_2[0][0]                 
                                                                 multiply_5[0][0]                 
__________________________________________________________________________________________________
dense_1 (Dense)                 multiple             720300      concatenate_1[0][0]              
                                                                 concatenate_2[0][0]              
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) multiple             1442400     dropout_1[2][0]                  
                                                                 dropout_1[3][0]                  
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 10, 1)        0           lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_10 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
lambda_12 (Lambda)              (None, 40, 1)        0           lambda_1[1][0]                   
__________________________________________________________________________________________________
multiply_6 (Multiply)           (None, 10, 600)      0           bidirectional_2[0][0]            
                                                                 lambda_6[0][0]                   
__________________________________________________________________________________________________
multiply_7 (Multiply)           (None, 10, 600)      0           bidirectional_2[0][0]            
                                                                 lambda_8[0][0]                   
__________________________________________________________________________________________________
multiply_8 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            
                                                                 lambda_10[0][0]                  
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 40, 600)      0           bidirectional_2[1][0]            
                                                                 lambda_12[0][0]                  
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 600)          0           multiply_6[0][0]                 
                                                                 lambda_6[0][0]                   
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 600)          0           multiply_7[0][0]                 
__________________________________________________________________________________________________
lambda_11 (Lambda)              (None, 600)          0           multiply_8[0][0]                 
                                                                 lambda_10[0][0]                  
__________________________________________________________________________________________________
lambda_13 (Lambda)              (None, 600)          0           multiply_9[0][0]                 
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 1200)         0           lambda_7[0][0]                   
                                                                 lambda_9[0][0]                   
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 1200)         0           lambda_11[0][0]                  
                                                                 lambda_13[0][0]                  
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 2400)         0           concatenate_3[0][0]              
                                                                 concatenate_4[0][0]              
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 300)          720300      concatenate_5[0][0]              
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 2)            602         dropout_1[4][0]                  
==================================================================================================
Total params: 6,256,502
Trainable params: 4,326,002
Non-trainable params: 1,930,500
__________________________________________________________________________________________________

In [7]:
# Callback that evaluates the model on (pred_X, pred_Y) during training.
# once_every=1 requests an evaluation after each epoch, and batch_size is set
# to the full size of the evaluation set so it is scored in a single batch.
evaluate = mz.callbacks.EvaluateAllMetrics(model,
                                           x=pred_X,
                                           y=pred_Y,
                                           once_every=1,
                                           batch_size=len(pred_Y))

# One-hot encode the labels for the 2-unit output layer of the model.
# The encoding is guarded so that this cell is idempotent: labels that still
# have the raw layout (1-D, or a trailing axis of size 1) are converted,
# while labels that were already one-hot encoded by a previous run of this
# cell are left untouched. Without the guard, re-running the cell would push
# the one-hot arrays through `to_categorical` again and change their shape.
if train_Y.ndim < 2 or train_Y.shape[-1] == 1:
    train_Y = to_categorical(train_Y)
if val_Y.ndim < 2 or val_Y.shape[-1] == 1:
    val_Y = to_categorical(val_Y)

# `embedding_matrix` is expected to come from an earlier cell of the notebook.
model.load_embedding_matrix(embedding_matrix)

# Train on the left/right text inputs; validation loss is computed on
# (val_X, val_Y) and the `evaluate` callback reports its metrics per epoch.
history = model.fit(x = [train_X['text_left'],
                         train_X['text_right']],
                    y = train_Y,
                    validation_data = (val_X, val_Y),
                    batch_size = batch_size,
                    epochs = epochs,
                    callbacks=[evaluate]
                    )


Train on 8627 samples, validate on 1130 samples
Epoch 1/5
8627/8627 [==============================] - 48s 6ms/step - loss: 0.3607 - val_loss: 0.3330
Validation: categorical_accuracy: 1.0
Epoch 2/5
8627/8627 [==============================] - 43s 5ms/step - loss: 0.3273 - val_loss: 0.3490
Validation: categorical_accuracy: 0.9451327323913574
Epoch 3/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.3096 - val_loss: 0.3498
Validation: categorical_accuracy: 0.9938052892684937
Epoch 4/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.2970 - val_loss: 0.3170
Validation: categorical_accuracy: 0.969911515712738
Epoch 5/5
8627/8627 [==============================] - 44s 5ms/step - loss: 0.2787 - val_loss: 0.3543
Validation: categorical_accuracy: 0.8778761029243469

In [14]:
# Score the trained model on the validation split; the returned metrics
# mapping is the cell's last expression, so the notebook displays it.
model.evaluate(x=val_X, y=val_Y)


Out[14]:
{'categorical_accuracy': 0.8920354}

In [ ]: