In [1]:
%run init.ipynb


Using TensorFlow backend.
matchzoo version 2.1.0

data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`

In [2]:
preprocessor = mz.preprocessors.DSSMPreprocessor()
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
valid_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)


Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 3587.72it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:04<00:00, 4528.13it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 592156.77it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 432217.30it/s]
Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 4239505.32it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:00<00:00, 2709.71it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 18841/18841 [00:11<00:00, 1656.57it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 1120.91it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:00<00:00, 1895.34it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 1910.44it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:01<00:00, 1630.79it/s]
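The `NgramLetter => WordHashing` steps above implement DSSM's letter-trigram word hashing: each word is padded with `#`, split into character trigrams, and the text becomes a bag-of-trigrams count vector over the trigram vocabulary. A minimal plain-Python sketch of the idea (not MatchZoo's implementation; the names here are illustrative):

```python
def letter_trigrams(word):
    """DSSM-style letter trigrams of a word, padded with '#' on both ends."""
    padded = f"#{word}#"
    return [padded[i:i + 3] for i in range(len(padded) - 2)]

def hash_text(words, vocab):
    """Bag-of-trigrams count vector over a trigram-index vocabulary."""
    vec = [0] * len(vocab)
    for word in words:
        for tri in letter_trigrams(word):
            if tri in vocab:
                vec[vocab[tri]] += 1
    return vec

# Toy trigram vocabulary built from two words
vocab = {}
for w in ["good", "dog"]:
    for tri in letter_trigrams(w):
        vocab.setdefault(tri, len(vocab))

print(letter_trigrams("good"))   # ['#go', 'goo', 'ood', 'od#']
print(hash_text(["dog"], vocab))
```

This is why `vocab_size` in the preprocessor context below is the number of distinct trigrams, and why each input to the model is a fixed-length vector of that size.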

In [3]:
preprocessor.context


Out[3]:
{'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7f0d7fe29fd0>,
 'vocab_size': 9645,
 'embedding_input_dim': 9645,
 'input_shapes': [(9645,), (9645,)]}

In [4]:
ranking_task = mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=4))
ranking_task.metrics = [
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),
    mz.metrics.MeanAveragePrecision()
]
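`RankCrossEntropyLoss(num_neg=4)` scores each positive document against 4 sampled negatives for the same query and takes the cross entropy of the softmax over those 5 scores. A scalar sketch of that shape (MatchZoo's actual loss operates on batched tensors):

```python
import math

def rank_cross_entropy(pos_score, neg_scores):
    """-log softmax probability of the positive among 1 pos + N neg scores."""
    scores = [pos_score] + list(neg_scores)
    m = max(scores)                          # shift for numerical stability
    exps = [math.exp(s - m) for s in scores]
    return -math.log(exps[0] / sum(exps))

# Chance level with num_neg=4: all five scores equal -> -log(1/5) ~= 1.609
print(rank_cross_entropy(0.0, [0.0, 0.0, 0.0, 0.0]))
```

That chance baseline of about 1.61 is a useful reference point for the training losses printed below.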

In [5]:
model = mz.models.DSSM()
model.params['input_shapes'] = preprocessor.context['input_shapes']
model.params['task'] = ranking_task
model.params['mlp_num_layers'] = 3
model.params['mlp_num_units'] = 300
model.params['mlp_num_fan_out'] = 128
model.params['mlp_activation_func'] = 'relu'
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.backend.summary()

append_params_to_readme(model)


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
text_left (InputLayer)          (None, 9645)         0                                            
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 9645)         0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 300)          2893800     text_left[0][0]                  
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 300)          2893800     text_right[0][0]                 
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 300)          90300       dense_1[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 300)          90300       dense_5[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 300)          90300       dense_2[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 300)          90300       dense_6[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 128)          38528       dense_3[0][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 128)          38528       dense_7[0][0]                    
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 1)            0           dense_4[0][0]                    
                                                                 dense_8[0][0]                    
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 1)            2           dot_1[0][0]                      
==================================================================================================
Total params: 6,225,858
Trainable params: 6,225,858
Non-trainable params: 0
__________________________________________________________________________________________________
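The parameter count follows directly from the hyperparameters set above: each tower is a Dense(9645 -> 300), two Dense(300 -> 300) layers, and the Dense(300 -> 128) fan-out; the two towers meet in a parameter-free dot product followed by a Dense(1) with 2 parameters. A quick arithmetic check:

```python
V, H, F = 9645, 300, 128  # trigram vocab size, mlp_num_units, mlp_num_fan_out

# One tower: input->hidden, two hidden->hidden, hidden->fan-out (weights + biases)
tower = (V * H + H) + 2 * (H * H + H) + (H * F + F)
total = 2 * tower + (1 * 1 + 1)  # two towers + Dense(1) on the dot product

print(tower, total)  # 3112928 6225858, matching the summary
```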

In [8]:
pred_x, pred_y = test_pack_processed[:].unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))

In [11]:
train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=1, num_neg=4, batch_size=32, shuffle=True)
len(train_generator)


WARNING: PairDataGenerator will be deprecated in MatchZoo v2.2. Use `DataGenerator` with callbacks instead.
Out[11]:
32
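`PairDataGenerator` with `num_dup=1, num_neg=4` builds, for each positive (query, document) pair, a group of that positive plus 4 negatives sampled from the same query; `len(train_generator)` is the number of batches of such groups. A toy sketch of the grouping logic (a hypothetical helper, not MatchZoo's internals):

```python
import random

def make_pairs(relations, num_dup=1, num_neg=4, seed=0):
    """Group each positive (query, doc) with `num_neg` sampled negatives
    from the same query. `relations` is a list of (query, doc, label)."""
    rng = random.Random(seed)
    by_query = {}
    for q, d, label in relations:
        by_query.setdefault(q, {"pos": [], "neg": []})
        by_query[q]["pos" if label > 0 else "neg"].append(d)
    groups = []
    for q, docs in by_query.items():
        for _ in range(num_dup):          # num_dup repeats each positive
            for pos in docs["pos"]:
                if len(docs["neg"]) < num_neg:
                    continue              # skip queries with too few negatives
                negs = rng.sample(docs["neg"], num_neg)
                groups.append([(q, pos)] + [(q, n) for n in negs])
    return groups

rels = [("q1", f"d{i}", int(i == 0)) for i in range(6)]  # 1 pos, 5 negs
groups = make_pairs(rels)
print(len(groups), len(groups[0]))  # 1 group of 5 docs (1 pos + 4 negs)
```

Each batch the generator yields therefore contains `batch_size` such 5-document groups, which is the layout `RankCrossEntropyLoss(num_neg=4)` expects.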

In [12]:
history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=5, use_multiprocessing=False)


Epoch 1/20
32/32 [==============================] - 7s 215ms/step - loss: 1.3325
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4431849853601904 - normalized_discounted_cumulative_gain@5(0.0): 0.5295386323998266 - mean_average_precision(0.0): 0.48303488812718776
Epoch 2/20
32/32 [==============================] - 6s 176ms/step - loss: 1.3159
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4353814849661657 - normalized_discounted_cumulative_gain@5(0.0): 0.5032525911610362 - mean_average_precision(0.0): 0.4776049822282439
Epoch 3/20
32/32 [==============================] - 5s 171ms/step - loss: 1.2955
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4088637099689691 - normalized_discounted_cumulative_gain@5(0.0): 0.48351010067595823 - mean_average_precision(0.0): 0.4432379861560312
Epoch 4/20
32/32 [==============================] - 6s 173ms/step - loss: 1.2726
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46569627211992487 - normalized_discounted_cumulative_gain@5(0.0): 0.5305277638291452 - mean_average_precision(0.0): 0.4903964896023526
Epoch 5/20
32/32 [==============================] - 6s 172ms/step - loss: 1.2439
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44778538209256513 - normalized_discounted_cumulative_gain@5(0.0): 0.5104380434420628 - mean_average_precision(0.0): 0.47615129143046664
Epoch 6/20
32/32 [==============================] - 6s 172ms/step - loss: 1.2202
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4452573045503587 - normalized_discounted_cumulative_gain@5(0.0): 0.5137975378931312 - mean_average_precision(0.0): 0.4742872412051932
Epoch 7/20
32/32 [==============================] - 5s 170ms/step - loss: 1.2038
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.41264292792428936 - normalized_discounted_cumulative_gain@5(0.0): 0.4740615140630128 - mean_average_precision(0.0): 0.45294026408574084
Epoch 8/20
32/32 [==============================] - 6s 172ms/step - loss: 1.1848
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45527149721829696 - normalized_discounted_cumulative_gain@5(0.0): 0.5229678873030444 - mean_average_precision(0.0): 0.48490323375232625
Epoch 9/20
32/32 [==============================] - 5s 171ms/step - loss: 1.1504
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4401749964298954 - normalized_discounted_cumulative_gain@5(0.0): 0.5202410581724496 - mean_average_precision(0.0): 0.47967943778482564
Epoch 10/20
32/32 [==============================] - 5s 172ms/step - loss: 1.1314
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44883790476151675 - normalized_discounted_cumulative_gain@5(0.0): 0.5215788412779597 - mean_average_precision(0.0): 0.48274548802838624
Epoch 11/20
32/32 [==============================] - 6s 173ms/step - loss: 1.1109
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45835958548802597 - normalized_discounted_cumulative_gain@5(0.0): 0.5254562351939174 - mean_average_precision(0.0): 0.48819163523037407
Epoch 12/20
32/32 [==============================] - 6s 174ms/step - loss: 1.0915
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4540812972538116 - normalized_discounted_cumulative_gain@5(0.0): 0.502728792326375 - mean_average_precision(0.0): 0.48229166522394096
Epoch 13/20
32/32 [==============================] - 6s 173ms/step - loss: 1.0805
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4462255256302118 - normalized_discounted_cumulative_gain@5(0.0): 0.5097488218798687 - mean_average_precision(0.0): 0.4751972950775518
Epoch 14/20
32/32 [==============================] - 6s 174ms/step - loss: 1.0575
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4263585495923587 - normalized_discounted_cumulative_gain@5(0.0): 0.5014903707963352 - mean_average_precision(0.0): 0.46364289738480496
Epoch 15/20
32/32 [==============================] - 6s 179ms/step - loss: 1.0396
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43936705108731194 - normalized_discounted_cumulative_gain@5(0.0): 0.5218713927469146 - mean_average_precision(0.0): 0.47233172236473137
Epoch 16/20
32/32 [==============================] - 6s 182ms/step - loss: 1.0156
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45080782122574514 - normalized_discounted_cumulative_gain@5(0.0): 0.5181271382497495 - mean_average_precision(0.0): 0.4832342072703635
Epoch 17/20
32/32 [==============================] - 6s 175ms/step - loss: 0.9932
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.423108628561739 - normalized_discounted_cumulative_gain@5(0.0): 0.49596605935842625 - mean_average_precision(0.0): 0.4667294180948952
Epoch 18/20
32/32 [==============================] - 5s 172ms/step - loss: 0.9800
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4378084124127128 - normalized_discounted_cumulative_gain@5(0.0): 0.5098753091251295 - mean_average_precision(0.0): 0.4734416114488085
Epoch 19/20
32/32 [==============================] - 6s 172ms/step - loss: 0.9662
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4504450479915345 - normalized_discounted_cumulative_gain@5(0.0): 0.519107636100811 - mean_average_precision(0.0): 0.48712867088141415
Epoch 20/20
32/32 [==============================] - 6s 172ms/step - loss: 0.9512
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45663442312293695 - normalized_discounted_cumulative_gain@5(0.0): 0.5363645153841258 - mean_average_precision(0.0): 0.4956098197015037

In [ ]: