In [1]:
%run init.ipynb


Using TensorFlow backend.
matchzoo version 2.1.0

data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
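
The three metrics registered on `ranking_task` each score a ranked list of candidate answers per question. The sketch below is a textbook formulation for a single toy query with binary relevance labels, not MatchZoo's own implementation; the function names and the toy labels are illustrative assumptions. Mean average precision is the mean of `average_precision` over all queries.

```python
import math

def dcg_at_k(labels, k):
    """Discounted cumulative gain of labels already sorted by predicted score."""
    return sum((2 ** rel - 1) / math.log2(rank + 2)
               for rank, rel in enumerate(labels[:k]))

def ndcg_at_k(labels, k):
    """DCG@k divided by the DCG@k of the ideal (label-sorted) ordering."""
    ideal = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k(labels, k) / ideal if ideal > 0 else 0.0

def average_precision(labels):
    """Mean of precision@rank taken at each rank that holds a relevant item."""
    hits, total = 0, 0.0
    for rank, rel in enumerate(labels, start=1):
        if rel > 0:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0

# Toy query (illustrative): relevance labels in the order the model ranked them.
ranked_labels = [0, 1, 0, 0, 1]
print(ndcg_at_k(ranked_labels, 3), ndcg_at_k(ranked_labels, 5),
      average_precision(ranked_labels))
```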

In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, 
                                                  fixed_length_right=100, 
                                                  remove_stop_words=False)

In [3]:
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)


Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9365.13it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5163.33it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 848689.58it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 151196.34it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 89439.89it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 706219.56it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 736492.25it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3032613.85it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9550.46it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5185.71it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 144261.80it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 239784.49it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 131422.51it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 620662.05it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 771237.80it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 140320.27it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 95088.12it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9513.19it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5119.33it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 10854.85it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 113134.00it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 129446.66it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 185736.87it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 583778.42it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 90215.99it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 93243.92it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 8866.58it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5131.75it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 147786.31it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 119176.36it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 111582.90it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 354132.54it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 543579.15it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 103730.57it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89680.20it/s]

In [4]:
preprocessor.context


Out[4]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7efc21e7ce48>,
 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7efc16b9ec88>,
 'vocab_size': 16674,
 'embedding_input_dim': 16674,
 'input_shapes': [(10,), (100,)]}

In [5]:
model = mz.models.ArcII()

# load `input_shapes` and `embedding_input_dim` (vocab_size)
model.params.update(preprocessor.context)

model.params['task'] = ranking_task
model.params['embedding_output_dim'] = 100
model.params['embedding_trainable'] = True
model.params['num_blocks'] = 2
model.params['kernel_1d_count'] = 32
model.params['kernel_1d_size'] = 3
model.params['kernel_2d_count'] = [64, 64]
model.params['kernel_2d_size'] = [3, 3]
model.params['pool_2d_size'] = [[3, 3], [3, 3]]
model.params['optimizer'] = 'adam'

model.build()
model.compile()

print(model.params)


model_class                   <class 'matchzoo.models.arcii.ArcII'>
input_shapes                  [(10,), (100,)]
task                          Ranking Task
optimizer                     adam
with_embedding                True
embedding_input_dim           16674
embedding_output_dim          100
embedding_trainable           True
num_blocks                    2
kernel_1d_count               32
kernel_1d_size                3
kernel_2d_count               [64, 64]
kernel_2d_size                [3, 3]
activation                    relu
pool_2d_size                  [[3, 3], [3, 3]]
padding                       same
dropout_rate                  0.0

In [6]:
model.backend.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
text_left (InputLayer)          (None, 10)           0                                            
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 100)          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           multiple             1667400     text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 10, 32)       9632        embedding[0][0]                  
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 100, 32)      9632        embedding[1][0]                  
__________________________________________________________________________________________________
matching_layer_1 (MatchingLayer (None, 10, 100, 32)  0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 10, 100, 64)  18496       matching_layer_1[0][0]           
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 3, 33, 64)    0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 3, 33, 64)    36928       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 1, 11, 64)    0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 704)          0           max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 704)          0           flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            705         dropout_1[0][0]                  
==================================================================================================
Total params: 1,742,793
Trainable params: 1,742,793
Non-trainable params: 0
__________________________________________________________________________________________________
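
The shapes and parameter counts in the summary above can be checked by hand. The sketch below recomputes them from the values set in `model.params`, assuming standard Keras conventions: a bias term on every convolution and dense layer, `same` padding for the convolutions, and floor division in the pooling layers. The helper names are illustrative and not part of MatchZoo.

```python
# Recompute the sizes reported by model.backend.summary().
# Helper names are illustrative; the arithmetic assumes standard Keras layers with biases.

vocab_size, embed_dim = 16674, 100   # embedding_input_dim, embedding_output_dim
len_left, len_right = 10, 100        # input_shapes

def conv1d_params(kernel_size, in_channels, filters):
    return kernel_size * in_channels * filters + filters

def conv2d_params(kernel_hw, in_channels, filters):
    return kernel_hw[0] * kernel_hw[1] * in_channels * filters + filters

embedding = vocab_size * embed_dim            # one table shared by both inputs
conv1d = conv1d_params(3, embed_dim, 32)      # one such layer per text side
conv2d_1 = conv2d_params((3, 3), 32, 64)
conv2d_2 = conv2d_params((3, 3), 64, 64)

# 'same' convolutions keep the 10 x 100 grid; each 3 x 3 pooling floor-divides it.
h, w = len_left, len_right
for pool_h, pool_w in [(3, 3), (3, 3)]:
    h, w = h // pool_h, w // pool_w
flat = h * w * 64                             # input width of the dense layer
dense = flat + 1                              # weights plus one bias

total = embedding + 2 * conv1d + conv2d_1 + conv2d_2 + dense
print(embedding, conv1d, conv2d_1, conv2d_2, flat, dense, total)
```

The embedding table accounts for 1,667,400 of the 1,742,793 parameters, so most of the model's capacity sits in the word vectors rather than in the convolutional layers.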

In [7]:
# Look up a GloVe vector for every term in the preprocessor's vocabulary.
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])

In [8]:
model.load_embedding_matrix(embedding_matrix)

In [9]:
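# Unpack the full test set; the callback scores it after every epoch and logs the result as "Validation".
test_x, test_y = test_pack_processed[:].unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=test_x, y=test_y, batch_size=len(test_y))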

In [10]:
train_generator = mz.DataGenerator(
    train_pack_processed,
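    mode='pair',   # yield (positive, negative) document pairs for each query
    num_dup=2,     # use each positive sample twice
    num_neg=1,     # draw one negative sample per positive
    batch_size=20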
)
print('num batches:', len(train_generator))


num batches: 102

In [11]:
history = model.fit_generator(train_generator, 
                              epochs=30, 
                              callbacks=[evaluate], 
                              workers=30, 
                              use_multiprocessing=True)


Epoch 1/30
102/102 [==============================] - 6s 61ms/step - loss: 0.6123
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6165843661519929 - normalized_discounted_cumulative_gain@5(0.0): 0.6639951229938149 - mean_average_precision(0.0): 0.6255013171310638
Epoch 2/30
102/102 [==============================] - 11s 107ms/step - loss: 0.3213
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795022399236163 - normalized_discounted_cumulative_gain@5(0.0): 0.6400033764936961 - mean_average_precision(0.0): 0.5927358064245385
Epoch 3/30
102/102 [==============================] - 12s 116ms/step - loss: 0.1556
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5734366624818106 - normalized_discounted_cumulative_gain@5(0.0): 0.6320835504730066 - mean_average_precision(0.0): 0.5873376891991933
Epoch 4/30
102/102 [==============================] - 12s 121ms/step - loss: 0.0966
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6047140052395196 - normalized_discounted_cumulative_gain@5(0.0): 0.6646231653619664 - mean_average_precision(0.0): 0.6177779169838045
Epoch 5/30
102/102 [==============================] - 12s 117ms/step - loss: 0.0688
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5804830097430296 - normalized_discounted_cumulative_gain@5(0.0): 0.6397107091735946 - mean_average_precision(0.0): 0.5861104257070671
Epoch 6/30
102/102 [==============================] - 12s 119ms/step - loss: 0.0741
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5733720686699197 - normalized_discounted_cumulative_gain@5(0.0): 0.6297166660287926 - mean_average_precision(0.0): 0.5793064781802681
Epoch 7/30
102/102 [==============================] - 12s 118ms/step - loss: 0.0545
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5612000374349604 - normalized_discounted_cumulative_gain@5(0.0): 0.6275035028226839 - mean_average_precision(0.0): 0.5812912862211518
Epoch 8/30
102/102 [==============================] - 13s 129ms/step - loss: 0.0351
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5695536121652267 - normalized_discounted_cumulative_gain@5(0.0): 0.6230819298416224 - mean_average_precision(0.0): 0.5829102115308369
Epoch 9/30
102/102 [==============================] - 13s 127ms/step - loss: 0.0546
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5747202712592626 - normalized_discounted_cumulative_gain@5(0.0): 0.6283361484936623 - mean_average_precision(0.0): 0.5858063874997032
Epoch 10/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0363
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5608760792872554 - normalized_discounted_cumulative_gain@5(0.0): 0.6165261187651955 - mean_average_precision(0.0): 0.5804023750891116
Epoch 11/30
102/102 [==============================] - 12s 121ms/step - loss: 0.0316
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5526488015162425 - normalized_discounted_cumulative_gain@5(0.0): 0.6190121151432394 - mean_average_precision(0.0): 0.5724829930234496
Epoch 12/30
102/102 [==============================] - 12s 122ms/step - loss: 0.0286
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5574638704067251 - normalized_discounted_cumulative_gain@5(0.0): 0.6221425949954883 - mean_average_precision(0.0): 0.5790691069305163
Epoch 13/30
102/102 [==============================] - 14s 135ms/step - loss: 0.0232
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5530705667805896 - normalized_discounted_cumulative_gain@5(0.0): 0.6045382570104455 - mean_average_precision(0.0): 0.560135777928777
Epoch 14/30
102/102 [==============================] - 12s 122ms/step - loss: 0.0232
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5551975290816926 - normalized_discounted_cumulative_gain@5(0.0): 0.6199269060404192 - mean_average_precision(0.0): 0.5695024812501234
Epoch 15/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0186
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5798942973253254 - normalized_discounted_cumulative_gain@5(0.0): 0.6326798983491299 - mean_average_precision(0.0): 0.5914300748272545
Epoch 16/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0224
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5549313451763698 - normalized_discounted_cumulative_gain@5(0.0): 0.609066461479371 - mean_average_precision(0.0): 0.5696612552826874
Epoch 17/30
102/102 [==============================] - 14s 134ms/step - loss: 0.0320
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5362590661756249 - normalized_discounted_cumulative_gain@5(0.0): 0.5960587553325881 - mean_average_precision(0.0): 0.5529867856626977
Epoch 18/30
102/102 [==============================] - 13s 123ms/step - loss: 0.0231
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5709133081694716 - normalized_discounted_cumulative_gain@5(0.0): 0.6294306655133045 - mean_average_precision(0.0): 0.5833828714158493
Epoch 19/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0210
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5408055010433294 - normalized_discounted_cumulative_gain@5(0.0): 0.6188813340014495 - mean_average_precision(0.0): 0.5639203433881982
Epoch 20/30
102/102 [==============================] - 11s 111ms/step - loss: 0.0099
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5857209979914645 - normalized_discounted_cumulative_gain@5(0.0): 0.6369323440919344 - mean_average_precision(0.0): 0.59214541005708
Epoch 21/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0314
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5621754667046626 - normalized_discounted_cumulative_gain@5(0.0): 0.6250852206304124 - mean_average_precision(0.0): 0.5746827039867791
Epoch 22/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0346
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5245291889962129 - normalized_discounted_cumulative_gain@5(0.0): 0.5957453960380752 - mean_average_precision(0.0): 0.5563949155969331
Epoch 23/30
102/102 [==============================] - 11s 104ms/step - loss: 0.0180
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5649023738386006 - normalized_discounted_cumulative_gain@5(0.0): 0.6281559244518438 - mean_average_precision(0.0): 0.5862423847100566
Epoch 24/30
102/102 [==============================] - 11s 110ms/step - loss: 0.0326
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5595631369578311 - normalized_discounted_cumulative_gain@5(0.0): 0.6216538205632761 - mean_average_precision(0.0): 0.5770444731020851
Epoch 25/30
102/102 [==============================] - 11s 107ms/step - loss: 0.0212
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5611397782325962 - normalized_discounted_cumulative_gain@5(0.0): 0.6275063453800273 - mean_average_precision(0.0): 0.5817784786583774
Epoch 26/30
102/102 [==============================] - 11s 110ms/step - loss: 0.0297
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5332907973060383 - normalized_discounted_cumulative_gain@5(0.0): 0.5988259655236242 - mean_average_precision(0.0): 0.5647126123155829
Epoch 27/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0244
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5448159675702459 - normalized_discounted_cumulative_gain@5(0.0): 0.5945159262523649 - mean_average_precision(0.0): 0.5558586926930239
Epoch 28/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0157
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.570165361029481 - normalized_discounted_cumulative_gain@5(0.0): 0.6316999513916105 - mean_average_precision(0.0): 0.5947381662168211
Epoch 29/30
102/102 [==============================] - 11s 107ms/step - loss: 0.0360
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.586639186111628 - normalized_discounted_cumulative_gain@5(0.0): 0.6352015410005084 - mean_average_precision(0.0): 0.596046073559104
Epoch 30/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0149
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.57738099248289 - normalized_discounted_cumulative_gain@5(0.0): 0.63122082754378 - mean_average_precision(0.0): 0.5877516720454762
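
In the log above the training loss keeps falling while the ranking metrics peak early (the highest mean average precision is logged after epoch 1), which suggests the model starts overfitting almost immediately. A minimal sketch for pulling the metric values out of log lines in this format and picking the best epoch follows; `parse_metrics` is a hypothetical helper, not a MatchZoo function, and the three lines are copied from the log.

```python
import re

# Matches "<metric name>: <value>" pairs in a "Validation: ..." log line.
# A metric name may carry an "@k" part and always ends in a "(threshold)" suffix.
METRIC_RE = re.compile(r"([a-z_]+(?:@\d+)?\([\d.]+\)): ([\d.]+)")

def parse_metrics(line):
    """Return {metric_name: value} for one 'Validation:' log line."""
    return {name: float(value) for name, value in METRIC_RE.findall(line)}

# Lines copied verbatim from the training log (epochs 1, 2, and 4).
log = {
    1: "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6165843661519929 - normalized_discounted_cumulative_gain@5(0.0): 0.6639951229938149 - mean_average_precision(0.0): 0.6255013171310638",
    2: "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795022399236163 - normalized_discounted_cumulative_gain@5(0.0): 0.6400033764936961 - mean_average_precision(0.0): 0.5927358064245385",
    4: "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6047140052395196 - normalized_discounted_cumulative_gain@5(0.0): 0.6646231653619664 - mean_average_precision(0.0): 0.6177779169838045",
}

metrics = {epoch: parse_metrics(line) for epoch, line in log.items()}
best_epoch = max(metrics, key=lambda e: metrics[e]["mean_average_precision(0.0)"])
print("best epoch by MAP:", best_epoch)
```

Because the callback in this notebook evaluates on the test set, choosing an epoch this way would tune on test data; `dev_pack_processed` is available for that purpose.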

Call `append_params_to_readme` to update README.md when you find a better set of parameters. The function appends to the file, so first delete the section of README.md that the new parameters replace.

In [12]:
append_params_to_readme(model)