In [1]:
%run init.ipynb
Using TensorFlow backend.
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [2]:
preprocessor = mz.preprocessors.CDSSMPreprocessor(fixed_length_left=10, fixed_length_right=10)
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
valid_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 5365.33it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:05<00:00, 3205.80it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 310569.71it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 349915.35it/s]
Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 3031577.24it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8384.04it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 3939.48it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 125562.34it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 96621.02it/s]
Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:07<00:00, 269.84it/s]
Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 18841/18841 [01:27<00:00, 216.37it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 7746.89it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:14<00:00, 77.28it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 63447.62it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 88946.88it/s]
Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 273.77it/s]
Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:04<00:00, 226.13it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 8707.90it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2237.22it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 101299.30it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 97484.78it/s]
Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 269.12it/s]
Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:10<00:00, 212.35it/s]
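The `NgramLetter => WordHashing` steps in the log above correspond to the letter-n-gram word hashing idea used by DSSM-style models. The following is a minimal sketch of that idea, assuming tri-grams and `#` as the word-boundary mark; the helper names `letter_ngrams`, `build_vocab`, and `hash_word` are hypothetical and are not MatchZoo's implementation.

```python
def letter_ngrams(word, n=3):
    """Split a word into letter n-grams after adding '#' boundary marks."""
    marked = f"#{word}#"
    if len(marked) < n:
        return [marked]
    return [marked[i:i + n] for i in range(len(marked) - n + 1)]


def build_vocab(words, n=3):
    """Map every distinct letter n-gram seen in `words` to an index."""
    grams = sorted({g for w in words for g in letter_ngrams(w, n)})
    return {g: i for i, g in enumerate(grams)}


def hash_word(word, vocab, n=3):
    """Return a count vector over the n-gram vocabulary for one word."""
    vec = [0] * len(vocab)
    for g in letter_ngrams(word, n):
        if g in vocab:  # n-grams unseen at fit time are ignored
            vec[vocab[g]] += 1
    return vec


vocab = build_vocab(["good", "food"])
print(letter_ngrams("good"))     # ['#go', 'goo', 'ood', 'od#']
print(hash_word("good", vocab))
```

Each token becomes a vector whose length is the n-gram vocabulary size, which is presumably why the model inputs further down have a last dimension of 9644.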
In [3]:
model = mz.models.CDSSM()
model.params['input_shapes'] = preprocessor.context['input_shapes']
model.params['task'] = ranking_task
model.params['filters'] = 64
model.params['kernel_size'] = 3
model.params['strides'] = 1
model.params['padding'] = 'same'
model.params['conv_activation_func'] = 'tanh'
model.params['w_initializer'] = 'glorot_normal'
model.params['b_initializer'] = 'zeros'
model.params['mlp_num_layers'] = 1
model.params['mlp_num_units'] = 64
model.params['mlp_num_fan_out'] = 64
model.params['mlp_activation_func'] = 'tanh'
model.params['dropout_rate'] = 0.8
model.params['optimizer'] = 'adadelta'
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.backend.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
text_left (InputLayer)          (None, 10, 9644)     0
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 10, 9644)     0
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 10, 64)       1851712     text_left[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 10, 64)       1851712     text_right[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 10, 64)       0           conv1d_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 10, 64)       0           conv1d_2[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM (None, 64)           0           dropout_1[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_2 (GlobalM (None, 64)           0           dropout_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 64)           4160        global_max_pooling1d_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 64)           4160        global_max_pooling1d_2[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 64)           4160        dense_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 64)           4160        dense_3[0][0]
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 1)            0           dense_2[0][0]
                                                                 dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 1)            2           dot_1[0][0]
==================================================================================================
Total params: 3,720,066
Trainable params: 3,720,066
Non-trainable params: 0
__________________________________________________________________________________________________
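The parameter counts in the summary can be checked by hand from the settings in the cell above (`kernel_size=3`, `filters=64`, 64 MLP units, input width 9644). The sketch below uses the standard weight-plus-bias formulas for 1D convolution and dense layers; the helper names are hypothetical.

```python
def conv1d_params(kernel_size, in_channels, filters):
    """Weights (kernel_size * in_channels * filters) plus one bias per filter."""
    return kernel_size * in_channels * filters + filters


def dense_params(in_units, out_units):
    """Weights (in_units * out_units) plus one bias per output unit."""
    return in_units * out_units + out_units


conv = conv1d_params(kernel_size=3, in_channels=9644, filters=64)
dense = dense_params(64, 64)
out = dense_params(1, 1)

# Two conv towers, four 64-unit dense layers, one scalar output layer.
total = 2 * conv + 4 * dense + out
print(conv, dense, out, total)  # 1851712 4160 2 3720066
```

Almost all of the 3.7M parameters sit in the two convolution layers, because each filter spans three positions of a 9644-wide tri-gram vector.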
In [4]:
pred_x, pred_y = test_pack_processed[:].unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))
train_generator = mz.DataGenerator(
    train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=1,
    batch_size=20
)
print('num batches:', len(train_generator))
num batches: 102
In [5]:
history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=1, use_multiprocessing=False)
Epoch 1/20
102/102 [==============================] - 65s 635ms/step - loss: 0.8021
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42304382290227854 - normalized_discounted_cumulative_gain@5(0.0): 0.49915948768338086 - mean_average_precision(0.0): 0.46037758752542035
Epoch 2/20
102/102 [==============================] - 45s 445ms/step - loss: 0.5966
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43781763271520285 - normalized_discounted_cumulative_gain@5(0.0): 0.520097599085372 - mean_average_precision(0.0): 0.4762598411822459
Epoch 3/20
102/102 [==============================] - 46s 447ms/step - loss: 0.4992
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44923101748788413 - normalized_discounted_cumulative_gain@5(0.0): 0.5136672947113214 - mean_average_precision(0.0): 0.4803110559647868
Epoch 4/20
102/102 [==============================] - 46s 446ms/step - loss: 0.4143
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.467714954371615 - normalized_discounted_cumulative_gain@5(0.0): 0.5353130653753986 - mean_average_precision(0.0): 0.5017560318255102
Epoch 5/20
102/102 [==============================] - 46s 451ms/step - loss: 0.3489
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4706511291875292 - normalized_discounted_cumulative_gain@5(0.0): 0.5275832328072992 - mean_average_precision(0.0): 0.5026243479583462
Epoch 6/20
102/102 [==============================] - 45s 443ms/step - loss: 0.3231
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4641151831570107 - normalized_discounted_cumulative_gain@5(0.0): 0.5219564667466021 - mean_average_precision(0.0): 0.4934049132027672
Epoch 7/20
102/102 [==============================] - 44s 433ms/step - loss: 0.2695
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4760514687477512 - normalized_discounted_cumulative_gain@5(0.0): 0.5285019348702702 - mean_average_precision(0.0): 0.49994736585416333
Epoch 8/20
102/102 [==============================] - 47s 457ms/step - loss: 0.2331
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4661781341235399 - normalized_discounted_cumulative_gain@5(0.0): 0.5260071867435453 - mean_average_precision(0.0): 0.4922605321622356
Epoch 9/20
102/102 [==============================] - 46s 453ms/step - loss: 0.1942
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4645719968337067 - normalized_discounted_cumulative_gain@5(0.0): 0.5238558790194195 - mean_average_precision(0.0): 0.48851468090847294
Epoch 10/20
102/102 [==============================] - 45s 444ms/step - loss: 0.1734
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4600910137969285 - normalized_discounted_cumulative_gain@5(0.0): 0.5320923473092672 - mean_average_precision(0.0): 0.48703092961044614
Epoch 11/20
102/102 [==============================] - 47s 464ms/step - loss: 0.1644
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45786306386326225 - normalized_discounted_cumulative_gain@5(0.0): 0.5246949873542252 - mean_average_precision(0.0): 0.48502089087514016
Epoch 12/20
102/102 [==============================] - 45s 443ms/step - loss: 0.1560
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4567332855642369 - normalized_discounted_cumulative_gain@5(0.0): 0.528074374789356 - mean_average_precision(0.0): 0.4905494464640722
Epoch 13/20
102/102 [==============================] - 45s 440ms/step - loss: 0.1365
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698836510431016 - normalized_discounted_cumulative_gain@5(0.0): 0.5317255666034969 - mean_average_precision(0.0): 0.49152222966181813
Epoch 14/20
102/102 [==============================] - 45s 437ms/step - loss: 0.1263
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46841048236088156 - normalized_discounted_cumulative_gain@5(0.0): 0.5197949838102164 - mean_average_precision(0.0): 0.4887341126171474
Epoch 15/20
102/102 [==============================] - 46s 449ms/step - loss: 0.1208
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4591952265806063 - normalized_discounted_cumulative_gain@5(0.0): 0.5306329604843507 - mean_average_precision(0.0): 0.4956899590808506
Epoch 16/20
102/102 [==============================] - 44s 430ms/step - loss: 0.0977
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698790408918565 - normalized_discounted_cumulative_gain@5(0.0): 0.5355447042513717 - mean_average_precision(0.0): 0.5005823464725863
Epoch 17/20
102/102 [==============================] - 44s 436ms/step - loss: 0.0975
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4699380823064665 - normalized_discounted_cumulative_gain@5(0.0): 0.5335843828018585 - mean_average_precision(0.0): 0.4945873841691485
Epoch 18/20
102/102 [==============================] - 44s 431ms/step - loss: 0.1070
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4558123512520484 - normalized_discounted_cumulative_gain@5(0.0): 0.5280824209271964 - mean_average_precision(0.0): 0.49009730599920476
Epoch 19/20
102/102 [==============================] - 45s 446ms/step - loss: 0.0846
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.47468488947984894 - normalized_discounted_cumulative_gain@5(0.0): 0.5462940161373581 - mean_average_precision(0.0): 0.5050693971440435
Epoch 20/20
102/102 [==============================] - 43s 425ms/step - loss: 0.0833
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4571219999869289 - normalized_discounted_cumulative_gain@5(0.0): 0.527098973778668 - mean_average_precision(0.0): 0.4884211445807594
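The log above reports NDCG@3, NDCG@5, and mean average precision after every epoch. As a rough illustration of what these numbers measure for a single query, here is a minimal sketch using one common formulation (gain `2**rel - 1`, log base 2 discount, labels greater than 0 treated as relevant). These choices are assumptions, and MatchZoo's own metric classes may differ in detail.

```python
import math


def rank_labels(labels, scores):
    """Order the relevance labels by descending model score."""
    return [l for _, l in sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)]


def dcg_at_k(ranked_labels, k):
    """Discounted cumulative gain over the top-k ranked labels."""
    return sum((2 ** rel - 1) / math.log2(i + 2)
               for i, rel in enumerate(ranked_labels[:k]))


def ndcg_at_k(labels, scores, k):
    """DCG of the model ranking divided by DCG of the ideal ranking."""
    ideal = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k(rank_labels(labels, scores), k) / ideal if ideal > 0 else 0.0


def average_precision(labels, scores):
    """Mean of precision@rank taken at each relevant document."""
    hits, total = 0, 0.0
    for rank, rel in enumerate(rank_labels(labels, scores), start=1):
        if rel > 0:  # assumption: any positive label counts as relevant
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


# One query with three candidates; the only relevant one is ranked second.
labels, scores = [0, 1, 0], [0.9, 0.8, 0.1]
print(ndcg_at_k(labels, scores, k=3))
print(average_precision(labels, scores))  # 0.5
```

The reported values would then be averages of such per-query scores over all queries in the evaluation pack.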
Content source: faneshion/MatchZoo