In [1]:
%run init.ipynb


Using TensorFlow backend.
matchzoo version 2.1.0

data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`

In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=False)
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)


Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3718.40it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2381.54it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 444321.96it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 66270.86it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 42664.36it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 330095.71it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 373073.88it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 1847527.98it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3771.43it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2367.53it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 69352.98it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 97079.34it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 74474.07it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 334659.48it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 414870.15it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 43066.26it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 32836.35it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 3487.84it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2442.70it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 67645.17it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 44196.33it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 76927.43it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 122242.02it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 333332.07it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 35858.80it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 28605.98it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 3830.90it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2071.67it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 70304.99it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 85642.29it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 81757.54it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 145883.48it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 372769.40it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 35687.87it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 33140.49it/s]

In [3]:
preprocessor.context


Out[3]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7faf25c75f98>,
 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7faf264499b0>,
 'vocab_size': 16674,
 'embedding_input_dim': 16674,
 'input_shapes': [(10,), (100,)]}
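
The context holds everything the preprocessor fitted: the frequency filter, the vocabulary (16,674 terms, which also sets `embedding_input_dim`), and the `input_shapes` implied by `fixed_length_left=10` and `fixed_length_right=100`. The fitted state can be read back directly; a small illustrative snippet (not an executed cell), with key names taken from the output above:

vocab_size = preprocessor.context['vocab_size']                      # 16674 here
term_index = preprocessor.context['vocab_unit'].state['term_index']  # term -> integer id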

In [4]:
ranking_task = mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=10))
ranking_task.metrics = [
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),
    mz.metrics.MeanAveragePrecision()
]
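
With `num_neg=10`, RankCrossEntropyLoss groups each positive document with ten sampled negatives and applies a softmax cross entropy over the eleven scores, pushing the model to score the positive above its negatives. A minimal numpy sketch of the per-group computation (illustrative only, assuming the positive score comes first in the group):

import numpy as np

scores = np.random.randn(11)              # 1 positive + 10 negative scores
probs = np.exp(scores - scores.max())
probs = probs / probs.sum()               # softmax over the group
loss = -np.log(probs[0])                  # cross entropy against the positive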

In [5]:
bin_size = 30
model = mz.models.DRMM()
model.params.update(preprocessor.context)
model.params['input_shapes'] = [[10,], [10, bin_size,]]  # override the context shapes: the right input is the per-term match histogram, not raw text
model.params['task'] = ranking_task
model.params['mask_value'] = 0
model.params['embedding_output_dim'] = glove_embedding.output_dim
model.params['mlp_num_layers'] = 1
model.params['mlp_num_units'] = 10
model.params['mlp_num_fan_out'] = 1
model.params['mlp_activation_func'] = 'tanh'
model.params['optimizer'] = 'adadelta'
model.build()
model.compile()
model.backend.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
text_left (InputLayer)          (None, 10)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 10, 300)      5002200     text_left[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10, 1)        300         embedding[0][0]                  
__________________________________________________________________________________________________
match_histogram (InputLayer)    (None, 10, 30)       0                                            
__________________________________________________________________________________________________
attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 10, 10)       310         match_histogram[0][0]            
__________________________________________________________________________________________________
attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]             
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 10, 1)        11          dense_2[0][0]                    
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]            
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 1)            0           dot_1[0][0]                      
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]                  
==================================================================================================
Total params: 5,002,823
Trainable params: 5,002,823
Non-trainable params: 0
__________________________________________________________________________________________________
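
The summary shows DRMM's two branches: the query side turns each of the 10 query terms into a gating logit (embedding -> dense_1), which the attention_mask / attention_probs Lambda layers convert into softmax weights with padded positions (mask_value=0) suppressed; the document side feeds each term's 30-bin match histogram through the small MLP (dense_2, dense_3). The Dot layer then takes the gate-weighted sum of per-term scores and dense_4 maps it to the final relevance score. A rough numpy sketch of the masked term gating, as an assumption about what the Lambda layers compute (for intuition only):

import numpy as np

def term_gates(gate_logits, is_pad):
    # gate_logits: (query_len,) per-term gating scores from dense_1
    # is_pad: boolean array marking padded query positions
    masked = np.where(is_pad, -1e9, gate_logits)  # padded terms get ~zero weight
    e = np.exp(masked - masked.max())
    return e / e.sum()                            # softmax attention weights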

In [6]:
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
# L2-normalize the word embeddings so that dot products equal cosine similarity, which speeds up histogram generation.
l2_norm = np.sqrt((embedding_matrix*embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]
model.load_embedding_matrix(embedding_matrix)
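
After the row-wise normalization, the cosine similarity between any two terms reduces to a plain dot product, which is exactly what the histogram callback below bins. A quick illustrative check (not an executed cell; the term ids are hypothetical):

q_vec, d_vec = embedding_matrix[5], embedding_matrix[42]   # two hypothetical term rows
cosine = float(q_vec.dot(d_vec))                           # in [-1, 1] because rows are unit-norm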

In [7]:
hist_callback = mz.data_generator.callbacks.Histogram(embedding_matrix, bin_size=30, hist_mode='LCH')
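
The Histogram callback builds DRMM's right-hand input on the fly: for each query term it collects the cosine similarities to every document term, buckets them into `bin_size=30` bins over [-1, 1], and with `hist_mode='LCH'` uses log-counts rather than raw counts. A rough sketch of one query term's histogram (illustrative; MatchZoo's exact binning, log base, and zero handling may differ):

import numpy as np

def lch_histogram(similarities, bin_size=30):
    # similarities: cosine similarities of one query term to all document terms
    counts, _ = np.histogram(similarities, bins=bin_size, range=(-1.0, 1.0))
    return np.log(counts + 1.0)   # log-count histogram; +1 avoids log(0)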

In [8]:
pred_generator = mz.DataGenerator(test_pack_processed, mode='point', callbacks=[hist_callback])
pred_x, pred_y = pred_generator[:]
evaluate = mz.callbacks.EvaluateAllMetrics(model, 
                                           x=pred_x, 
                                           y=pred_y, 
                                           once_every=1, 
                                           batch_size=len(pred_y),
                                           model_save_path='./drmm_pretrained_model/'
                                          )
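
EvaluateAllMetrics scores the whole test pack once per epoch (here in a single batch of len(pred_y) samples) and writes a checkpoint under `model_save_path`; judging by the `mz.load_model('./drmm_pretrained_model/16')` call further down, one checkpoint is kept per epoch, indexed by epoch number.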

In [9]:
train_generator = mz.DataGenerator(train_pack_processed, mode='pair', num_dup=5, num_neg=10, batch_size=20, 
                                   callbacks=[hist_callback])
print('num batches:', len(train_generator))


num batches: 255
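
In pair mode the generator rebuilds query groups: each group pairs one positive document with `num_neg=10` negatives, `num_dup=5` replicates every positive before grouping, and `batch_size=20` counts groups rather than rows. A back-of-the-envelope sketch of the resulting batch size (an assumption about MatchZoo's pair mode, for intuition only):

rows_per_group = 1 + 10                  # one positive plus num_neg negatives
rows_per_batch = 20 * rows_per_group     # batch_size groups -> 220 rows per batch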

In [10]:
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)


Epoch 1/30
255/255 [==============================] - 29s 113ms/step - loss: 2.2520
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.565716344425806 - normalized_discounted_cumulative_gain@5(0.0): 0.6337418608669659 - mean_average_precision(0.0): 0.5867500331707677
Epoch 2/30
255/255 [==============================] - 44s 172ms/step - loss: 1.9129
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5985553932954132 - normalized_discounted_cumulative_gain@5(0.0): 0.6538053305162407 - mean_average_precision(0.0): 0.6119736749640002
Epoch 3/30
255/255 [==============================] - 44s 172ms/step - loss: 1.5735
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6009435695818752 - normalized_discounted_cumulative_gain@5(0.0): 0.6639328074555028 - mean_average_precision(0.0): 0.6141421880590333
Epoch 4/30
255/255 [==============================] - 44s 174ms/step - loss: 1.3388
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.60611654032328 - normalized_discounted_cumulative_gain@5(0.0): 0.657932990992144 - mean_average_precision(0.0): 0.6164241968035815
Epoch 5/30
255/255 [==============================] - 44s 173ms/step - loss: 1.2257
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6038327788753987 - normalized_discounted_cumulative_gain@5(0.0): 0.658207699559794 - mean_average_precision(0.0): 0.6165270742594192
Epoch 6/30
255/255 [==============================] - 45s 175ms/step - loss: 1.1705
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5996133695926983 - normalized_discounted_cumulative_gain@5(0.0): 0.6492549188602507 - mean_average_precision(0.0): 0.610038711779037
Epoch 7/30
255/255 [==============================] - 49s 194ms/step - loss: 1.1245
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.596951218733747 - normalized_discounted_cumulative_gain@5(0.0): 0.6537028205494899 - mean_average_precision(0.0): 0.6118594078618916
Epoch 8/30
255/255 [==============================] - 49s 192ms/step - loss: 1.0965
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.607384889011673 - normalized_discounted_cumulative_gain@5(0.0): 0.6591796919495273 - mean_average_precision(0.0): 0.6198677092456238
Epoch 9/30
255/255 [==============================] - 50s 196ms/step - loss: 1.0739
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6009170916196322 - normalized_discounted_cumulative_gain@5(0.0): 0.6585982395889274 - mean_average_precision(0.0): 0.61618597649167
Epoch 10/30
255/255 [==============================] - 50s 196ms/step - loss: 1.0617
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6033747474109487 - normalized_discounted_cumulative_gain@5(0.0): 0.6578329756640785 - mean_average_precision(0.0): 0.6179567226451466
Epoch 11/30
255/255 [==============================] - 51s 199ms/step - loss: 1.0452
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5997577679210652 - normalized_discounted_cumulative_gain@5(0.0): 0.653391959851057 - mean_average_precision(0.0): 0.6107675669593364
Epoch 12/30
255/255 [==============================] - 50s 197ms/step - loss: 1.0463
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6099707901807351 - normalized_discounted_cumulative_gain@5(0.0): 0.6576795760995855 - mean_average_precision(0.0): 0.6190800467908035
Epoch 13/30
255/255 [==============================] - 48s 189ms/step - loss: 1.0142
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6060017367789287 - normalized_discounted_cumulative_gain@5(0.0): 0.6561934627591615 - mean_average_precision(0.0): 0.6164463654573828
Epoch 14/30
255/255 [==============================] - 48s 190ms/step - loss: 1.0166
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5978152368225886 - normalized_discounted_cumulative_gain@5(0.0): 0.6505786036930443 - mean_average_precision(0.0): 0.61033087775462
Epoch 15/30
255/255 [==============================] - 49s 192ms/step - loss: 0.9978
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5922250615836083 - normalized_discounted_cumulative_gain@5(0.0): 0.6426732309894715 - mean_average_precision(0.0): 0.6022889901471488
Epoch 16/30
255/255 [==============================] - 49s 191ms/step - loss: 1.0025
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5750131412934796 - normalized_discounted_cumulative_gain@5(0.0): 0.6352853673494635 - mean_average_precision(0.0): 0.5900977589685756
Epoch 17/30
255/255 [==============================] - 51s 199ms/step - loss: 0.9862
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5758042381295027 - normalized_discounted_cumulative_gain@5(0.0): 0.6346575981857837 - mean_average_precision(0.0): 0.5901888624427178
Epoch 18/30
255/255 [==============================] - 49s 193ms/step - loss: 0.9855
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.579621264998872 - normalized_discounted_cumulative_gain@5(0.0): 0.6347686906366622 - mean_average_precision(0.0): 0.5884915666561875
Epoch 19/30
255/255 [==============================] - 49s 192ms/step - loss: 0.9766
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5828950659531305 - normalized_discounted_cumulative_gain@5(0.0): 0.6372592629514049 - mean_average_precision(0.0): 0.5918219888593219
Epoch 20/30
255/255 [==============================] - 48s 190ms/step - loss: 0.9680
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5825563351056543 - normalized_discounted_cumulative_gain@5(0.0): 0.6391007734169843 - mean_average_precision(0.0): 0.5938890573634718
Epoch 21/30
255/255 [==============================] - 49s 191ms/step - loss: 0.9508
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.579128022658977 - normalized_discounted_cumulative_gain@5(0.0): 0.6395338143562516 - mean_average_precision(0.0): 0.5941989088930982
Epoch 22/30
255/255 [==============================] - 50s 196ms/step - loss: 0.9522
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5757112036272017 - normalized_discounted_cumulative_gain@5(0.0): 0.6293793465362543 - mean_average_precision(0.0): 0.583203117231657
Epoch 23/30
255/255 [==============================] - 52s 203ms/step - loss: 0.9465
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5712280953991604 - normalized_discounted_cumulative_gain@5(0.0): 0.633607426512819 - mean_average_precision(0.0): 0.583670818529751
Epoch 24/30
255/255 [==============================] - 51s 200ms/step - loss: 0.9434
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.580920539337657 - normalized_discounted_cumulative_gain@5(0.0): 0.641607241519095 - mean_average_precision(0.0): 0.5944285471787385
Epoch 25/30
255/255 [==============================] - 51s 198ms/step - loss: 0.9410
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5804534006493391 - normalized_discounted_cumulative_gain@5(0.0): 0.6365914442161427 - mean_average_precision(0.0): 0.5909732454825581
Epoch 26/30
255/255 [==============================] - 48s 190ms/step - loss: 0.9288
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5706608764817882 - normalized_discounted_cumulative_gain@5(0.0): 0.6287030620891625 - mean_average_precision(0.0): 0.5812929865633114
Epoch 27/30
255/255 [==============================] - 49s 193ms/step - loss: 0.9390
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5609829297013442 - normalized_discounted_cumulative_gain@5(0.0): 0.6256839520090155 - mean_average_precision(0.0): 0.5787581434265507
Epoch 28/30
255/255 [==============================] - 49s 191ms/step - loss: 0.9276
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5719190626216555 - normalized_discounted_cumulative_gain@5(0.0): 0.6307496688555204 - mean_average_precision(0.0): 0.5817515076104938
Epoch 29/30
255/255 [==============================] - 49s 193ms/step - loss: 0.9183
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795917195993303 - normalized_discounted_cumulative_gain@5(0.0): 0.629216010151653 - mean_average_precision(0.0): 0.5871250977359159
Epoch 30/30
255/255 [==============================] - 51s 200ms/step - loss: 0.9114
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5699846325745904 - normalized_discounted_cumulative_gain@5(0.0): 0.6288458929530589 - mean_average_precision(0.0): 0.581158143460481

In [11]:
drmm_model = mz.load_model('./drmm_pretrained_model/16')
test_generator = mz.DataGenerator(data_pack=dev_pack_processed[:10], mode='point', callbacks=[hist_callback])
test_x, test_y = test_generator[:]
prediction = drmm_model.predict(test_x)
prediction


Out[11]:
array([[ 2.1248043 ],
       [-0.7405751 ],
       [ 1.2434225 ],
       [-3.161619  ],
       [-1.5212955 ],
       [-1.3774216 ],
       [-0.43242887],
       [-3.8191142 ],
       [-2.490344  ],
       [ 1.9016179 ]], dtype=float32)
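
The outputs are raw relevance scores, so within a single query the documents are ranked by sorting scores in descending order. A one-line sketch (illustrative; the ten rows above may span more than one query of the dev pack):

ranking = prediction.ravel().argsort()[::-1]   # indices from highest to lowest score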

In [12]:
import shutil
shutil.rmtree('./drmm_pretrained_model/')
