In [1]:
%run init.ipynb
Using TensorFlow backend.
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
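The ranking task above reports NDCG@3, NDCG@5, and mean average precision. As a rough illustration of what these numbers measure for a single query, here is a minimal sketch of the textbook definitions. It is not MatchZoo's own implementation; the function names are hypothetical, and reading the `(0.0)` suffix as a relevance threshold (labels above it count as relevant) is an assumption.

```python
import math


def _labels_by_score(labels, scores):
    """Return the labels reordered from highest to lowest predicted score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [labels[i] for i in order]


def ndcg_at_k(labels, scores, k):
    """NDCG@k with gain 2^label - 1 and a log2 position discount."""
    def dcg(ordered):
        return sum((2 ** rel - 1) / math.log2(rank + 2)
                   for rank, rel in enumerate(ordered[:k]))

    ideal = dcg(sorted(labels, reverse=True))
    return dcg(_labels_by_score(labels, scores)) / ideal if ideal > 0 else 0.0


def average_precision(labels, scores, threshold=0.0):
    """Average precision for one query; MAP is the mean of this over queries."""
    hits, total = 0, 0.0
    for rank, rel in enumerate(_labels_by_score(labels, scores), start=1):
        if rel > threshold:  # assumed meaning of the "(0.0)" suffix
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


# One query with four candidate answers, two of them relevant.
labels = [0, 1, 0, 1]
scores = [0.9, 0.8, 0.1, 0.7]
print(ndcg_at_k(labels, scores, k=3), average_precision(labels, scores))
```

Both metrics reward placing relevant answers near the top of the ranking, which is why they can move independently of the training loss.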
In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10,
                                                  fixed_length_right=100,
                                                  remove_stop_words=False)
In [3]:
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9365.13it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5163.33it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 848689.58it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 151196.34it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 89439.89it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 706219.56it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 736492.25it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3032613.85it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9550.46it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5185.71it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 144261.80it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 239784.49it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 131422.51it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 620662.05it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 771237.80it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 140320.27it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 95088.12it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9513.19it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5119.33it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 10854.85it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 113134.00it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 129446.66it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 185736.87it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 583778.42it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 90215.99it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 93243.92it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 8866.58it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 5131.75it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 147786.31it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 119176.36it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 111582.90it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 354132.54it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 543579.15it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 103730.57it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89680.20it/s]
In [4]:
preprocessor.context
Out[4]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7efc21e7ce48>,
 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7efc16b9ec88>,
 'vocab_size': 16674,
 'embedding_input_dim': 16674,
 'input_shapes': [(10,), (100,)]}
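The `input_shapes` of `(10,)` and `(100,)` match the `fixed_length_left` and `fixed_length_right` values passed to the preprocessor: every left text becomes 10 token ids and every right text 100. The sketch below shows the general idea of forcing a token-id list to a fixed length. The helper `to_fixed_length` and its parameters are hypothetical, and which side MatchZoo pads or truncates on is an assumption here, so both are left configurable.

```python
def to_fixed_length(token_ids, length, pad_value=0, pad_front=True, cut_front=True):
    """Pad or truncate a list of token ids to exactly `length` items (length > 0)."""
    if len(token_ids) >= length:
        # Too long: keep either the last or the first `length` tokens.
        return token_ids[-length:] if cut_front else token_ids[:length]
    padding = [pad_value] * (length - len(token_ids))
    # Too short: add padding before or after the real tokens.
    return padding + token_ids if pad_front else token_ids + padding


short_question = [5, 6, 7]
long_question = list(range(1, 13))  # 12 token ids

print(to_fixed_length(short_question, 10))
print(to_fixed_length(long_question, 10))
```

Whatever the padding side, the result is a rectangular batch, which is what lets the model declare static input shapes.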
In [5]:
model = mz.models.ArcII()
# load `input_shapes` and `embedding_input_dim` (vocab_size)
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['embedding_output_dim'] = 100
model.params['embedding_trainable'] = True
model.params['num_blocks'] = 2
model.params['kernel_1d_count'] = 32
model.params['kernel_1d_size'] = 3
model.params['kernel_2d_count'] = [64, 64]
model.params['kernel_2d_size'] = [3, 3]
model.params['pool_2d_size'] = [[3, 3], [3, 3]]
model.params['optimizer'] = 'adam'
model.build()
model.compile()
print(model.params)
model_class                   <class 'matchzoo.models.arcii.ArcII'>
input_shapes                  [(10,), (100,)]
task                          Ranking Task
optimizer                     adam
with_embedding                True
embedding_input_dim           16674
embedding_output_dim          100
embedding_trainable           True
num_blocks                    2
kernel_1d_count               32
kernel_1d_size                3
kernel_2d_count               [64, 64]
kernel_2d_size                [3, 3]
activation                    relu
pool_2d_size                  [[3, 3], [3, 3]]
padding                       same
dropout_rate                  0.0
In [6]:
model.backend.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
text_left (InputLayer)          (None, 10)           0
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 100)          0
__________________________________________________________________________________________________
embedding (Embedding)           multiple             1667400     text_left[0][0]
                                                                 text_right[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 10, 32)       9632        embedding[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 100, 32)      9632        embedding[1][0]
__________________________________________________________________________________________________
matching_layer_1 (MatchingLayer (None, 10, 100, 32)  0           conv1d_1[0][0]
                                                                 conv1d_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 10, 100, 64)  18496       matching_layer_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 3, 33, 64)    0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 3, 33, 64)    36928       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 1, 11, 64)    0           conv2d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 704)          0           max_pooling2d_2[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 704)          0           flatten_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            705         dropout_1[0][0]
==================================================================================================
Total params: 1,742,793
Trainable params: 1,742,793
Non-trainable params: 0
__________________________________________________________________________________________________
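The parameter counts and output shapes in this summary can be reproduced by hand from the model parameters set earlier. The following is a small arithmetic check, assuming the usual weight-plus-bias counts for convolution and dense layers and floor division for the pooling output sizes; the variable names are illustrative only.

```python
vocab_size, emb_dim = 16674, 100        # embedding_input_dim, embedding_output_dim
k1_count, k1_size = 32, 3               # kernel_1d_count, kernel_1d_size
k2_counts, k2_size = [64, 64], (3, 3)   # kernel_2d_count, kernel_2d_size
pools = [(3, 3), (3, 3)]                # pool_2d_size

embedding = vocab_size * emb_dim                      # one vector per vocabulary id
conv1d = k1_size * emb_dim * k1_count + k1_count      # weights + biases, per text side

# Two 2D convolution blocks; the first one sees the 32 channels of the matching layer.
conv2d, in_channels = [], k1_count
for out_channels in k2_counts:
    conv2d.append(k2_size[0] * k2_size[1] * in_channels * out_channels + out_channels)
    in_channels = out_channels

# Each pooling step shrinks the 10 x 100 interaction map by floor division.
height, width = 10, 100
for pool_h, pool_w in pools:
    height, width = height // pool_h, width // pool_w

flat = height * width * k2_counts[-1]
dense = flat + 1                                      # one output unit plus bias

total = embedding + 2 * conv1d + sum(conv2d) + dense
print(embedding, conv1d, conv2d, (height, width), flat, dense, total)
```

The embedding table accounts for 1,667,400 of the 1,742,793 parameters, so with `embedding_trainable` set to `True` most of the model's capacity sits in the word vectors.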
In [7]:
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
In [8]:
model.load_embedding_matrix(embedding_matrix)
In [9]:
test_x, test_y = test_pack_processed[:].unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=test_x, y=test_y, batch_size=len(test_y))
In [10]:
train_generator = mz.DataGenerator(
    train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=1,
    batch_size=20
)
print('num batches:', len(train_generator))
num batches: 102
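In `pair` mode the generator feeds the model groups made of one more-relevant answer and `num_neg` less-relevant answers for the same question, with each positive repeated `num_dup` times. The sketch below illustrates that idea for a single question. It is a simplified stand-in rather than MatchZoo's actual sampling code; `build_pairs` and the `(doc_id, label)` input format are hypothetical.

```python
import random


def build_pairs(candidates, num_dup=1, num_neg=1, seed=0):
    """Build [positive, negative, ...] groups for one question.

    candidates: list of (doc_id, label) tuples for the same question.
    """
    rng = random.Random(seed)
    groups = []
    for doc_id, label in candidates:
        # Any answer with a strictly lower label can serve as a negative.
        negatives = [d for d, other in candidates if other < label]
        if not negatives:
            continue
        for _ in range(num_dup):
            sampled = [rng.choice(negatives) for _ in range(num_neg)]
            groups.append([doc_id] + sampled)
    return groups


candidates = [("d1", 1), ("d2", 0), ("d3", 0)]
for group in build_pairs(candidates, num_dup=2, num_neg=1):
    print(group)
```

Duplicating positives lets the same relevant answer be contrasted with different negatives, which matters when negatives far outnumber positives.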
In [11]:
history = model.fit_generator(train_generator,
                              epochs=30,
                              callbacks=[evaluate],
                              workers=30,
                              use_multiprocessing=True)
Epoch 1/30
102/102 [==============================] - 6s 61ms/step - loss: 0.6123
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6165843661519929 - normalized_discounted_cumulative_gain@5(0.0): 0.6639951229938149 - mean_average_precision(0.0): 0.6255013171310638
Epoch 2/30
102/102 [==============================] - 11s 107ms/step - loss: 0.3213
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795022399236163 - normalized_discounted_cumulative_gain@5(0.0): 0.6400033764936961 - mean_average_precision(0.0): 0.5927358064245385
Epoch 3/30
102/102 [==============================] - 12s 116ms/step - loss: 0.1556
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5734366624818106 - normalized_discounted_cumulative_gain@5(0.0): 0.6320835504730066 - mean_average_precision(0.0): 0.5873376891991933
Epoch 4/30
102/102 [==============================] - 12s 121ms/step - loss: 0.0966
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6047140052395196 - normalized_discounted_cumulative_gain@5(0.0): 0.6646231653619664 - mean_average_precision(0.0): 0.6177779169838045
Epoch 5/30
102/102 [==============================] - 12s 117ms/step - loss: 0.0688
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5804830097430296 - normalized_discounted_cumulative_gain@5(0.0): 0.6397107091735946 - mean_average_precision(0.0): 0.5861104257070671
Epoch 6/30
102/102 [==============================] - 12s 119ms/step - loss: 0.0741
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5733720686699197 - normalized_discounted_cumulative_gain@5(0.0): 0.6297166660287926 - mean_average_precision(0.0): 0.5793064781802681
Epoch 7/30
102/102 [==============================] - 12s 118ms/step - loss: 0.0545
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5612000374349604 - normalized_discounted_cumulative_gain@5(0.0): 0.6275035028226839 - mean_average_precision(0.0): 0.5812912862211518
Epoch 8/30
102/102 [==============================] - 13s 129ms/step - loss: 0.0351
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5695536121652267 - normalized_discounted_cumulative_gain@5(0.0): 0.6230819298416224 - mean_average_precision(0.0): 0.5829102115308369
Epoch 9/30
102/102 [==============================] - 13s 127ms/step - loss: 0.0546
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5747202712592626 - normalized_discounted_cumulative_gain@5(0.0): 0.6283361484936623 - mean_average_precision(0.0): 0.5858063874997032
Epoch 10/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0363
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5608760792872554 - normalized_discounted_cumulative_gain@5(0.0): 0.6165261187651955 - mean_average_precision(0.0): 0.5804023750891116
Epoch 11/30
102/102 [==============================] - 12s 121ms/step - loss: 0.0316
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5526488015162425 - normalized_discounted_cumulative_gain@5(0.0): 0.6190121151432394 - mean_average_precision(0.0): 0.5724829930234496
Epoch 12/30
102/102 [==============================] - 12s 122ms/step - loss: 0.0286
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5574638704067251 - normalized_discounted_cumulative_gain@5(0.0): 0.6221425949954883 - mean_average_precision(0.0): 0.5790691069305163
Epoch 13/30
102/102 [==============================] - 14s 135ms/step - loss: 0.0232
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5530705667805896 - normalized_discounted_cumulative_gain@5(0.0): 0.6045382570104455 - mean_average_precision(0.0): 0.560135777928777
Epoch 14/30
102/102 [==============================] - 12s 122ms/step - loss: 0.0232
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5551975290816926 - normalized_discounted_cumulative_gain@5(0.0): 0.6199269060404192 - mean_average_precision(0.0): 0.5695024812501234
Epoch 15/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0186
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5798942973253254 - normalized_discounted_cumulative_gain@5(0.0): 0.6326798983491299 - mean_average_precision(0.0): 0.5914300748272545
Epoch 16/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0224
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5549313451763698 - normalized_discounted_cumulative_gain@5(0.0): 0.609066461479371 - mean_average_precision(0.0): 0.5696612552826874
Epoch 17/30
102/102 [==============================] - 14s 134ms/step - loss: 0.0320
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5362590661756249 - normalized_discounted_cumulative_gain@5(0.0): 0.5960587553325881 - mean_average_precision(0.0): 0.5529867856626977
Epoch 18/30
102/102 [==============================] - 13s 123ms/step - loss: 0.0231
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5709133081694716 - normalized_discounted_cumulative_gain@5(0.0): 0.6294306655133045 - mean_average_precision(0.0): 0.5833828714158493
Epoch 19/30
102/102 [==============================] - 12s 120ms/step - loss: 0.0210
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5408055010433294 - normalized_discounted_cumulative_gain@5(0.0): 0.6188813340014495 - mean_average_precision(0.0): 0.5639203433881982
Epoch 20/30
102/102 [==============================] - 11s 111ms/step - loss: 0.0099
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5857209979914645 - normalized_discounted_cumulative_gain@5(0.0): 0.6369323440919344 - mean_average_precision(0.0): 0.59214541005708
Epoch 21/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0314
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5621754667046626 - normalized_discounted_cumulative_gain@5(0.0): 0.6250852206304124 - mean_average_precision(0.0): 0.5746827039867791
Epoch 22/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0346
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5245291889962129 - normalized_discounted_cumulative_gain@5(0.0): 0.5957453960380752 - mean_average_precision(0.0): 0.5563949155969331
Epoch 23/30
102/102 [==============================] - 11s 104ms/step - loss: 0.0180
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5649023738386006 - normalized_discounted_cumulative_gain@5(0.0): 0.6281559244518438 - mean_average_precision(0.0): 0.5862423847100566
Epoch 24/30
102/102 [==============================] - 11s 110ms/step - loss: 0.0326
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5595631369578311 - normalized_discounted_cumulative_gain@5(0.0): 0.6216538205632761 - mean_average_precision(0.0): 0.5770444731020851
Epoch 25/30
102/102 [==============================] - 11s 107ms/step - loss: 0.0212
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5611397782325962 - normalized_discounted_cumulative_gain@5(0.0): 0.6275063453800273 - mean_average_precision(0.0): 0.5817784786583774
Epoch 26/30
102/102 [==============================] - 11s 110ms/step - loss: 0.0297
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5332907973060383 - normalized_discounted_cumulative_gain@5(0.0): 0.5988259655236242 - mean_average_precision(0.0): 0.5647126123155829
Epoch 27/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0244
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5448159675702459 - normalized_discounted_cumulative_gain@5(0.0): 0.5945159262523649 - mean_average_precision(0.0): 0.5558586926930239
Epoch 28/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0157
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.570165361029481 - normalized_discounted_cumulative_gain@5(0.0): 0.6316999513916105 - mean_average_precision(0.0): 0.5947381662168211
Epoch 29/30
102/102 [==============================] - 11s 107ms/step - loss: 0.0360
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.586639186111628 - normalized_discounted_cumulative_gain@5(0.0): 0.6352015410005084 - mean_average_precision(0.0): 0.596046073559104
Epoch 30/30
102/102 [==============================] - 11s 108ms/step - loss: 0.0149
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.57738099248289 - normalized_discounted_cumulative_gain@5(0.0): 0.63122082754378 - mean_average_precision(0.0): 0.5877516720454762
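In the log above the training loss falls from 0.6123 to 0.0149, while the reported ranking metrics are highest in the first few epochs (MAP peaks at epoch 1), which suggests the model overfits quickly. A small helper like the one below can pull the best epoch for a metric out of such a log. It is a sketch that assumes the `Epoch n/N` and `Validation: ...` line format shown here; `best_epoch` is a hypothetical name, and only the first four epochs are copied in to keep the example short.

```python
import re

LOG = """\
Epoch 1/30
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6165843661519929 - normalized_discounted_cumulative_gain@5(0.0): 0.6639951229938149 - mean_average_precision(0.0): 0.6255013171310638
Epoch 2/30
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795022399236163 - normalized_discounted_cumulative_gain@5(0.0): 0.6400033764936961 - mean_average_precision(0.0): 0.5927358064245385
Epoch 3/30
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5734366624818106 - normalized_discounted_cumulative_gain@5(0.0): 0.6320835504730066 - mean_average_precision(0.0): 0.5873376891991933
Epoch 4/30
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6047140052395196 - normalized_discounted_cumulative_gain@5(0.0): 0.6646231653619664 - mean_average_precision(0.0): 0.6177779169838045
"""


def best_epoch(log_text, metric):
    """Return (epoch, value) for the highest value of `metric` in the log."""
    value_pattern = re.compile(re.escape(metric) + r": ([0-9.]+)")
    epoch, best = None, (None, float("-inf"))
    for line in log_text.splitlines():
        header = re.match(r"Epoch (\d+)/\d+", line)
        if header:
            epoch = int(header.group(1))  # remember which epoch the next lines belong to
            continue
        if line.startswith("Validation:"):
            found = value_pattern.search(line)
            if found and float(found.group(1)) > best[1]:
                best = (epoch, float(found.group(1)))
    return best


print(best_epoch(LOG, "mean_average_precision(0.0)"))
print(best_epoch(LOG, "normalized_discounted_cumulative_gain@5(0.0)"))
```

Different metrics can peak at different epochs, so the metric used for model selection should be chosen before looking at the results.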
In [12]:
append_params_to_readme(model)
Content source: faneshion/MatchZoo