In [1]:
%run init.ipynb
Using TensorFlow backend.
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(
    fixed_length_left=10,
    fixed_length_right=100,
    remove_stop_words=False
)
In [3]:
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9354.30it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5242.39it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 959307.59it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 144447.70it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 151024.12it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 818306.55it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 742442.92it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 2591651.03it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9654.46it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5546.36it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 138535.63it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 254476.95it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 125907.17it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 631516.02it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 868328.96it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 138403.01it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 61442.21it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9297.98it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5474.46it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 123108.59it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 109807.96it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 142493.87it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 193974.64it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 688349.86it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 110281.27it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 86857.14it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 9041.67it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 4067.08it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 123996.13it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 143565.86it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 129280.34it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 309383.77it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 887315.97it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 101288.98it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 85815.05it/s]
In [4]:
preprocessor.context
Out[4]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x1115c2f28>,
'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x13ec386d8>,
'vocab_size': 16674,
'embedding_input_dim': 16674,
'input_shapes': [(10,), (100,)]}
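The `input_shapes` of `(10,)` and `(100,)` follow from `fixed_length_left=10` and `fixed_length_right=100`: every query and document is brought to a fixed number of token ids. A minimal sketch of that idea, assuming right-side padding with id `0` and truncation from the end; `pad_or_truncate` is a hypothetical helper, not MatchZoo's API, and the library's own padding and truncation sides may differ.

```python
from typing import List


def pad_or_truncate(token_ids: List[int], fixed_length: int, pad_value: int = 0) -> List[int]:
    """Return exactly `fixed_length` ids (hypothetical helper, not part of MatchZoo)."""
    # Keep at most `fixed_length` ids from the front of the sequence.
    clipped = token_ids[:fixed_length]
    # Fill the remaining positions with the padding id (assumed to be 0 here).
    return clipped + [pad_value] * (fixed_length - len(clipped))


short_query = [5, 12, 7]
long_document = list(range(1, 151))

print(pad_or_truncate(short_query, 10))          # [5, 12, 7, 0, 0, 0, 0, 0, 0, 0]
print(len(pad_or_truncate(long_document, 100)))  # 100
```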
In [5]:
model = mz.models.DRMMTKS()
# load `input_shapes` and `embedding_input_dim` (vocab_size)
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['mask_value'] = -1
model.params['embedding_output_dim'] = glove_embedding.output_dim
model.params['embedding_trainable'] = True
model.params['top_k'] = 20
model.params['mlp_num_layers'] = 1
model.params['mlp_num_units'] = 5
model.params['mlp_num_fan_out'] = 1
model.params['mlp_activation_func'] = 'relu'
model.params['optimizer'] = 'adadelta'
model.build()
model.compile()
print(model.params)
model_class <class 'matchzoo.models.drmmtks.DRMMTKS'>
input_shapes [(10,), (100,)]
task Ranking Task
optimizer adadelta
with_embedding True
embedding_input_dim 16674
embedding_output_dim 100
embedding_trainable True
with_multi_layer_perceptron True
mlp_num_units 5
mlp_num_layers 1
mlp_num_fan_out 1
mlp_activation_func relu
mask_value -1
top_k 20
In [6]:
model.backend.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
text_left (InputLayer)          (None, 10)           0
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 100)          0
__________________________________________________________________________________________________
embedding (Embedding)           multiple             1667400     text_left[0][0]
                                                                 text_right[0][0]
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 10, 100)      0           embedding[0][0]
                                                                 embedding[1][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10, 1)        100         embedding[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 10, 10)       0           dot_1[0][0]
__________________________________________________________________________________________________
attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 10, 5)        55          lambda_1[0][0]
__________________________________________________________________________________________________
attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 10, 1)        6           dense_2[0][0]
__________________________________________________________________________________________________
dot_2 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]
                                                                 dense_3[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 1)            0           dot_2[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]
==================================================================================================
Total params: 1,667,563
Trainable params: 1,667,563
Non-trainable params: 0
__________________________________________________________________________________________________
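The summary shows a `Dot` layer between the two embedded inputs followed by a `Lambda` layer that shrinks the last axis, which matches the top-k selection controlled by the `top_k` parameter. The sketch below illustrates that step in NumPy on toy sizes. It assumes the `Lambda` layer keeps the largest matching scores per query term; `topk_matching_signals` is a hypothetical name and this is not the library's implementation.

```python
import numpy as np


def topk_matching_signals(query_emb: np.ndarray, doc_emb: np.ndarray, top_k: int) -> np.ndarray:
    """Keep the `top_k` largest query-document dot products for each query term."""
    # (query_len, dim) @ (dim, doc_len) -> (query_len, doc_len) matching matrix.
    matching = query_emb @ doc_emb.T
    # Sort every row in descending order and keep the first `top_k` columns.
    return -np.sort(-matching, axis=1)[:, :top_k]


rng = np.random.default_rng(0)
toy_query = rng.normal(size=(3, 4))     # 3 query terms, embedding size 4
toy_document = rng.normal(size=(6, 4))  # 6 document terms, embedding size 4

signals = topk_matching_signals(toy_query, toy_document, top_k=2)
print(signals.shape)  # (3, 2)
```

Each row of `signals` is sorted from strongest to weakest match, so a small per-term network can read it as a fixed-length feature vector regardless of document length.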
In [7]:
term_index = preprocessor.context['vocab_unit'].state['term_index']
embedding_matrix = glove_embedding.build_matrix(term_index)
l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]
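The cell above scales every embedding row to unit L2 length, so the dot product between two word vectors equals their cosine similarity. A self-contained sketch of the same normalization follows. The `eps` guard against all-zero rows is an addition for illustration, and `l2_normalize_rows` is a hypothetical helper name.

```python
import numpy as np


def l2_normalize_rows(matrix: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm; all-zero rows stay zero instead of becoming NaN."""
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Clamp tiny norms so that division never produces NaN or inf.
    return matrix / np.maximum(norms, eps)


toy_matrix = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
normalized = l2_normalize_rows(toy_matrix)
print(normalized)
# [[0.6 0.8]
#  [0.  0. ]
#  [1.  0. ]]
```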
In [8]:
model.load_embedding_matrix(embedding_matrix)
In [9]:
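test_x, test_y = test_pack_processed.unpack()
# `test_x` is a dict of arrays, so `len(test_x)` counts its keys rather than samples;
# use `len(test_y)` to evaluate the whole test set in a single batch.
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=test_x, y=test_y, batch_size=len(test_y))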
In [12]:
train_generator = mz.DataGenerator(
    train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=1,
    batch_size=20
)
print('num batches:', len(train_generator))
num batches: 408
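The generator above runs in `pair` mode with `num_dup=2` and `num_neg=1`. A rough sketch of how pairwise sampling of this kind is commonly done, assuming binary relevance labels: each relevant document is repeated `num_dup` times, and each repetition is grouped with `num_neg` randomly drawn non-relevant documents for the same query. `make_pair_groups` is a hypothetical helper, not MatchZoo's implementation, whose exact sampling rules may differ.

```python
import random
from typing import Dict, List, Tuple


def make_pair_groups(
    relevance: Dict[str, Dict[str, int]],
    num_dup: int = 1,
    num_neg: int = 1,
    seed: int = 0,
) -> List[Tuple[str, str, List[str]]]:
    """Build (query, positive, [negatives]) groups from binary relevance labels."""
    rng = random.Random(seed)
    groups = []
    for query_id, docs in relevance.items():
        positives = [doc for doc, label in docs.items() if label > 0]
        negatives = [doc for doc, label in docs.items() if label == 0]
        if not negatives:
            continue  # nothing to contrast the positives against
        for positive in positives:
            for _ in range(num_dup):
                sampled = [rng.choice(negatives) for _ in range(num_neg)]
                groups.append((query_id, positive, sampled))
    return groups


toy_relevance = {
    "q1": {"d1": 1, "d2": 0, "d3": 0},
    "q2": {"d4": 1, "d5": 1, "d6": 0},
}
pair_groups = make_pair_groups(toy_relevance, num_dup=2, num_neg=1)
print(len(pair_groups))  # 6 groups: (1 + 2) positives, each duplicated twice
```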
In [13]:
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=4, use_multiprocessing=True)
Epoch 1/30
408/408 [==============================] - 7s 17ms/step - loss: 0.5841
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5988918016460031 - normalized_discounted_cumulative_gain@5(0.0): 0.6590526112135905 - mean_average_precision(0.0): 0.6149729212603963
Epoch 2/30
408/408 [==============================] - 16s 40ms/step - loss: 0.1286
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5986225503033814 - normalized_discounted_cumulative_gain@5(0.0): 0.6573151832954626 - mean_average_precision(0.0): 0.6081435794322055
Epoch 3/30
408/408 [==============================] - 17s 41ms/step - loss: 0.0236
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5770594699106932 - normalized_discounted_cumulative_gain@5(0.0): 0.642764019599169 - mean_average_precision(0.0): 0.5915165228479636
Epoch 4/30
408/408 [==============================] - 20s 48ms/step - loss: 0.0085
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5630076023551297 - normalized_discounted_cumulative_gain@5(0.0): 0.6314235897212269 - mean_average_precision(0.0): 0.5811220052309326
Epoch 5/30
408/408 [==============================] - 21s 52ms/step - loss: 0.0039
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5721875341813376 - normalized_discounted_cumulative_gain@5(0.0): 0.6430346137953176 - mean_average_precision(0.0): 0.5887684999840622
Epoch 6/30
408/408 [==============================] - 18s 44ms/step - loss: 0.0025
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5635600485727307 - normalized_discounted_cumulative_gain@5(0.0): 0.636814704018869 - mean_average_precision(0.0): 0.5831658911191673
Epoch 7/30
408/408 [==============================] - 19s 47ms/step - loss: 0.0019
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5624051725623126 - normalized_discounted_cumulative_gain@5(0.0): 0.6362646813359831 - mean_average_precision(0.0): 0.5830741350566369
Epoch 8/30
408/408 [==============================] - 20s 48ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5627439034097886 - normalized_discounted_cumulative_gain@5(0.0): 0.6330671467013023 - mean_average_precision(0.0): 0.5806811015299475
Epoch 9/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0018
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5679681248986371 - normalized_discounted_cumulative_gain@5(0.0): 0.6363760316713785 - mean_average_precision(0.0): 0.5851179931317684
Epoch 10/30
408/408 [==============================] - 22s 54ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5745461650682577 - normalized_discounted_cumulative_gain@5(0.0): 0.6407437747766428 - mean_average_precision(0.0): 0.5893184033052796
Epoch 11/30
408/408 [==============================] - 21s 50ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5684307665143323 - normalized_discounted_cumulative_gain@5(0.0): 0.634110297384481 - mean_average_precision(0.0): 0.5844463619617193
Epoch 12/30
408/408 [==============================] - 20s 48ms/step - loss: 0.0018
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.566182378404992 - normalized_discounted_cumulative_gain@5(0.0): 0.6318466678703829 - mean_average_precision(0.0): 0.583959239462311
Epoch 13/30
408/408 [==============================] - 21s 50ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5671372070359241 - normalized_discounted_cumulative_gain@5(0.0): 0.6288034745152886 - mean_average_precision(0.0): 0.5823057242012007
Epoch 14/30
408/408 [==============================] - 24s 59ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5688445292639432 - normalized_discounted_cumulative_gain@5(0.0): 0.6289506917256232 - mean_average_precision(0.0): 0.5823084213744116
Epoch 15/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0016
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5616594491502604 - normalized_discounted_cumulative_gain@5(0.0): 0.6274555733102793 - mean_average_precision(0.0): 0.5803460611842033
Epoch 16/30
408/408 [==============================] - 24s 59ms/step - loss: 0.0019
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.56743611685671 - normalized_discounted_cumulative_gain@5(0.0): 0.6279655493568921 - mean_average_precision(0.0): 0.5833285217616007
Epoch 17/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0019
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.56743611685671 - normalized_discounted_cumulative_gain@5(0.0): 0.629412929341684 - mean_average_precision(0.0): 0.5831775027001493
Epoch 18/30
408/408 [==============================] - 21s 53ms/step - loss: 0.0017
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5616594491502604 - normalized_discounted_cumulative_gain@5(0.0): 0.6230767811708631 - mean_average_precision(0.0): 0.5778751707969239
Epoch 19/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0014
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5624755943132025 - normalized_discounted_cumulative_gain@5(0.0): 0.6196682608534165 - mean_average_precision(0.0): 0.5773872198947884
Epoch 20/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0017
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.556962625552094 - normalized_discounted_cumulative_gain@5(0.0): 0.6181385580129402 - mean_average_precision(0.0): 0.5762216174819781
Epoch 21/30
408/408 [==============================] - 22s 53ms/step - loss: 0.0018
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5532956624869947 - normalized_discounted_cumulative_gain@5(0.0): 0.614101122041995 - mean_average_precision(0.0): 0.570779766786963
Epoch 22/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0019
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5580675179872963 - normalized_discounted_cumulative_gain@5(0.0): 0.615310109292882 - mean_average_precision(0.0): 0.5723212571524482
Epoch 23/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0017
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.555957813345946 - normalized_discounted_cumulative_gain@5(0.0): 0.6168348059854845 - mean_average_precision(0.0): 0.5745875152731366
Epoch 24/30
408/408 [==============================] - 20s 50ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5545892219654028 - normalized_discounted_cumulative_gain@5(0.0): 0.6161692037842695 - mean_average_precision(0.0): 0.5737084716725741
Epoch 25/30
408/408 [==============================] - 21s 51ms/step - loss: 0.0015
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5572513728243542 - normalized_discounted_cumulative_gain@5(0.0): 0.6149353178966624 - mean_average_precision(0.0): 0.5726138530937782
Epoch 26/30
408/408 [==============================] - 20s 50ms/step - loss: 0.0019
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5530319635416537 - normalized_discounted_cumulative_gain@5(0.0): 0.6157477651673351 - mean_average_precision(0.0): 0.5719727220729004
Epoch 27/30
408/408 [==============================] - 22s 55ms/step - loss: 0.0018
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5535491988838099 - normalized_discounted_cumulative_gain@5(0.0): 0.6157701973864509 - mean_average_precision(0.0): 0.5706400633221479
Epoch 28/30
408/408 [==============================] - 24s 58ms/step - loss: 0.0017
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.556713699306524 - normalized_discounted_cumulative_gain@5(0.0): 0.6164860410473986 - mean_average_precision(0.0): 0.5726880826961166
Epoch 29/30
408/408 [==============================] - 25s 61ms/step - loss: 0.0015
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5586585674438334 - normalized_discounted_cumulative_gain@5(0.0): 0.6187911692779081 - mean_average_precision(0.0): 0.5763274637089153
Epoch 30/30
408/408 [==============================] - 24s 58ms/step - loss: 0.0014
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5642215197801578 - normalized_discounted_cumulative_gain@5(0.0): 0.6253202724847128 - mean_average_precision(0.0): 0.5820402362707493
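In the log above the training loss keeps falling while all three ranking metrics peak at epoch 1 (NDCG@3 of about 0.599, MAP of about 0.615), so the metrics, not the loss, are the quantities to monitor. The sketch below shows the textbook definitions of NDCG@k and average precision for a single query. It assumes an exponential gain and a log2 discount; MatchZoo's own metric classes may differ in details such as the relevance threshold and tie handling.

```python
import math
from typing import List, Sequence


def _rank_labels(labels: Sequence[int], scores: Sequence[float]) -> List[int]:
    """Order the labels by descending predicted score."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    return [label for _, label in ranked]


def dcg_at_k(ranked_labels: Sequence[int], k: int) -> float:
    """Discounted cumulative gain over the first k ranked labels."""
    return sum((2 ** label - 1) / math.log2(rank + 2)
               for rank, label in enumerate(ranked_labels[:k]))


def ndcg_at_k(labels: Sequence[int], scores: Sequence[float], k: int) -> float:
    """DCG@k divided by the DCG@k of an ideal ordering."""
    ideal_dcg = dcg_at_k(sorted(labels, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0
    return dcg_at_k(_rank_labels(labels, scores), k) / ideal_dcg


def average_precision(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Mean of precision@rank taken at every relevant document."""
    hits, total = 0, 0.0
    for rank, label in enumerate(_rank_labels(labels, scores), start=1):
        if label > 0:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


toy_labels = [0, 1, 0, 1]
toy_scores = [0.9, 0.8, 0.3, 0.2]
print(round(ndcg_at_k(toy_labels, toy_scores, k=3), 4))  # 0.3869
print(average_precision(toy_labels, toy_scores))         # 0.5
```

Mean average precision is then the mean of `average_precision` over all queries, and the reported NDCG values are likewise averaged per query.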
In [ ]:
# append_params_to_readme(model)
Content source: faneshion/MatchZoo