In [1]:
%run init.ipynb
Using TensorFlow backend.
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=False)
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3724.89it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2470.05it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 445941.69it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 70474.81it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 47716.48it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 342940.70it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 379138.05it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 1847244.28it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 4067.10it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2516.30it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 67851.85it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 89496.74it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 75255.27it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 356739.86it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 400286.10it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 42205.68it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 36443.25it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 4076.75it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2482.03it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 68513.29it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 74397.37it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 81470.46it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 124654.10it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 340342.69it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 36081.31it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 34092.08it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 3945.55it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2229.53it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 69715.12it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 91685.12it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 81219.95it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 168597.36it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 361686.38it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 36193.34it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 37573.42it/s]
In [3]:
preprocessor.context
Out[3]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7f4c81abd208>,
'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7f4c7e0039b0>,
'vocab_size': 16674,
'embedding_input_dim': 16674,
'input_shapes': [(10,), (40,)]}
In [4]:
model = mz.models.ConvKNRM()
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['embedding_output_dim'] = glove_embedding.output_dim
model.params['embedding_trainable'] = True
model.params['filters'] = 128
model.params['conv_activation_func'] = 'tanh'
model.params['max_ngram'] = 3
model.params['use_crossmatch'] = True
model.params['kernel_num'] = 11
model.params['sigma'] = 0.1
model.params['exact_sigma'] = 0.001
model.params['optimizer'] = 'adadelta'
model.build()
model.compile()
#model.backend.summary()
In [5]:
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
model.load_embedding_matrix(embedding_matrix)
In [6]:
pred_x, pred_y = test_pack_processed.unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))
In [7]:
train_generator = mz.DataGenerator(
train_pack_processed,
mode='pair',
num_dup=5,
num_neg=1,
batch_size=20
)
print('num batches:', len(train_generator))
num batches: 255
In [8]:
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)
Epoch 1/30
255/255 [==============================] - 39s 151ms/step - loss: 0.4776
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6046833551308459 - normalized_discounted_cumulative_gain@5(0.0): 0.6665959185697871 - mean_average_precision(0.0): 0.6129317182476187
Epoch 2/30
255/255 [==============================] - 42s 165ms/step - loss: 0.1211
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6152022692439065 - normalized_discounted_cumulative_gain@5(0.0): 0.6744661220937255 - mean_average_precision(0.0): 0.6343442709387471
Epoch 3/30
255/255 [==============================] - 43s 167ms/step - loss: 0.0620
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5548312155628421 - normalized_discounted_cumulative_gain@5(0.0): 0.6213976803862951 - mean_average_precision(0.0): 0.5710894912635418
Epoch 4/30
255/255 [==============================] - 43s 168ms/step - loss: 0.0417
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5954135056241098 - normalized_discounted_cumulative_gain@5(0.0): 0.6578139885107471 - mean_average_precision(0.0): 0.6111818707546555
Epoch 5/30
255/255 [==============================] - 42s 165ms/step - loss: 0.0356
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6028110339356899 - normalized_discounted_cumulative_gain@5(0.0): 0.6587081531783964 - mean_average_precision(0.0): 0.609672147851597
Epoch 6/30
255/255 [==============================] - 43s 167ms/step - loss: 0.0168
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6162374354841582 - normalized_discounted_cumulative_gain@5(0.0): 0.6730370811514368 - mean_average_precision(0.0): 0.6224984341793723
Epoch 7/30
255/255 [==============================] - 44s 171ms/step - loss: 0.0173
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6140235133098497 - normalized_discounted_cumulative_gain@5(0.0): 0.6702416124777638 - mean_average_precision(0.0): 0.6291591558219198
Epoch 8/30
255/255 [==============================] - 42s 167ms/step - loss: 0.0105
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6170925436547549 - normalized_discounted_cumulative_gain@5(0.0): 0.6721917042350951 - mean_average_precision(0.0): 0.6319963125925817
Epoch 9/30
255/255 [==============================] - 44s 173ms/step - loss: 0.0093
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6128277115636099 - normalized_discounted_cumulative_gain@5(0.0): 0.6723112242287127 - mean_average_precision(0.0): 0.6202358378889221
Epoch 10/30
255/255 [==============================] - 44s 172ms/step - loss: 0.0074
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6014563078514086 - normalized_discounted_cumulative_gain@5(0.0): 0.6534368691098246 - mean_average_precision(0.0): 0.6111100251131897
Epoch 11/30
255/255 [==============================] - 44s 174ms/step - loss: 0.0040
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6227886270617885 - normalized_discounted_cumulative_gain@5(0.0): 0.671961802050252 - mean_average_precision(0.0): 0.6294700258349492
Epoch 12/30
255/255 [==============================] - 43s 170ms/step - loss: 0.0049
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6090539823114218 - normalized_discounted_cumulative_gain@5(0.0): 0.6657954120017808 - mean_average_precision(0.0): 0.6204152162428167
Epoch 13/30
255/255 [==============================] - 44s 172ms/step - loss: 0.0045
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.605268576648384 - normalized_discounted_cumulative_gain@5(0.0): 0.6743645292160183 - mean_average_precision(0.0): 0.6225096268464414
Epoch 14/30
255/255 [==============================] - 44s 173ms/step - loss: 0.0056
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6094653731796733 - normalized_discounted_cumulative_gain@5(0.0): 0.6626866921233837 - mean_average_precision(0.0): 0.6167465211769009
Epoch 15/30
255/255 [==============================] - 44s 174ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.636181345906811 - normalized_discounted_cumulative_gain@5(0.0): 0.6839535736459206 - mean_average_precision(0.0): 0.6396845610127441
Epoch 16/30
255/255 [==============================] - 44s 171ms/step - loss: 0.0020
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6242068770913558 - normalized_discounted_cumulative_gain@5(0.0): 0.6742098356127264 - mean_average_precision(0.0): 0.632390750330996
Epoch 17/30
255/255 [==============================] - 44s 173ms/step - loss: 9.0911e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305977929317165 - normalized_discounted_cumulative_gain@5(0.0): 0.6760885220490657 - mean_average_precision(0.0): 0.6345498605597264
Epoch 18/30
255/255 [==============================] - 44s 172ms/step - loss: 5.5907e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6302281364338485 - normalized_discounted_cumulative_gain@5(0.0): 0.6801281668512741 - mean_average_precision(0.0): 0.6365815560435815
Epoch 19/30
255/255 [==============================] - 44s 174ms/step - loss: 0.0011
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6161750510905308 - normalized_discounted_cumulative_gain@5(0.0): 0.6828588425017452 - mean_average_precision(0.0): 0.6399974866560868
Epoch 20/30
255/255 [==============================] - 44s 173ms/step - loss: 5.9813e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6192304771389455 - normalized_discounted_cumulative_gain@5(0.0): 0.6835492157788869 - mean_average_precision(0.0): 0.6339323846467368
Epoch 21/30
255/255 [==============================] - 43s 170ms/step - loss: 9.7819e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6306940079499387 - normalized_discounted_cumulative_gain@5(0.0): 0.6811198627999581 - mean_average_precision(0.0): 0.6436540061956451
Epoch 22/30
255/255 [==============================] - 44s 171ms/step - loss: 7.8472e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6246823923728032 - normalized_discounted_cumulative_gain@5(0.0): 0.6828731850846279 - mean_average_precision(0.0): 0.6295758423592317
Epoch 27/30
255/255 [==============================] - 44s 172ms/step - loss: 3.4582e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6202967663352117 - normalized_discounted_cumulative_gain@5(0.0): 0.6718788743837789 - mean_average_precision(0.0): 0.6227787543285545
Epoch 28/30
255/255 [==============================] - 45s 177ms/step - loss: 0.0011
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6221372690184138 - normalized_discounted_cumulative_gain@5(0.0): 0.6775309597373549 - mean_average_precision(0.0): 0.6240701328331634
Epoch 29/30
255/255 [==============================] - 45s 177ms/step - loss: 9.9084e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6193374900161305 - normalized_discounted_cumulative_gain@5(0.0): 0.6749664782289659 - mean_average_precision(0.0): 0.6316795270723043
Epoch 30/30
255/255 [==============================] - 45s 177ms/step - loss: 4.5966e-04
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6325513939822278 - normalized_discounted_cumulative_gain@5(0.0): 0.6801122164684699 - mean_average_precision(0.0): 0.6351265617879543
In [ ]:
Content source: faneshion/MatchZoo
Similar notebooks: