In [2]:
%run init.ipynb
Using TensorFlow backend.
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [3]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=False)
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
valid_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3684.04it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2461.23it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 461585.85it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 71870.86it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 38605.84it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 283592.53it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 346305.69it/s]
Building Vocabulary from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 1310721.01it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 4026.58it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:08<00:00, 2330.13it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 62427.97it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 76941.04it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 60446.57it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 336141.06it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 395438.78it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 44951.02it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 37867.11it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 4134.42it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2569.53it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 54590.38it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 76373.89it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 7920.67it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 99033.31it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 333903.25it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 38113.00it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 39241.53it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 4120.57it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 2342.51it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 51395.86it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 74881.36it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 76694.17it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 177034.74it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 367123.31it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 41574.66it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 37033.81it/s]
In [6]:
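model = mz.models.MVLSTM()
# Reuse the settings collected by the preprocessor (e.g. vocabulary size, input shapes).
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['embedding_output_dim'] = 300  # dimensionality of the word embeddings
model.params['lstm_units'] = 50  # per direction; each bidirectional layer outputs 100 features
model.params['top_k'] = 20  # keep the 20 strongest interaction values
# MLP head: two hidden layers of 10 units, then a 5-unit fan-out layer.
model.params['mlp_num_layers'] = 2
model.params['mlp_num_units'] = 10
model.params['mlp_num_fan_out'] = 5
model.params['mlp_activation_func'] = 'relu'
model.params['dropout_rate'] = 0.5
model.params['optimizer'] = 'adadelta'
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.backend.summary()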
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
text_left (InputLayer)          (None, 10)           0
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 40)           0
__________________________________________________________________________________________________
embedding (Embedding)           multiple             5002200     text_left[0][0]
                                                                 text_right[0][0]
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) (None, 10, 100)      140400      embedding[0][0]
__________________________________________________________________________________________________
bidirectional_4 (Bidirectional) (None, 40, 100)      140400      embedding[1][0]
__________________________________________________________________________________________________
dot_2 (Dot)                     (None, 10, 40)       0           bidirectional_3[0][0]
                                                                 bidirectional_4[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 400)          0           dot_2[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 20)           0           reshape_2[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 10)           210         lambda_4[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 10)           110         dense_5[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 5)            55          dense_6[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 5)            0           dense_7[0][0]
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 1)            6           dropout_2[0][0]
==================================================================================================
Total params: 5,283,381
Trainable params: 5,283,381
Non-trainable params: 0
__________________________________________________________________________________________________
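The parameter counts in the summary can be reproduced by hand, which is a quick way to check that the hyperparameters landed where intended. The sketch below uses the standard Keras counting rules for embedding, LSTM, and dense layers. The vocabulary size is not printed in the notebook; it is inferred here by dividing the embedding parameter count by the embedding dimension, so that entry is an assumption rather than an independent check. The helper names are illustrative only.

```python
def embedding_params(vocab_size, dim):
    # One dim-sized vector per vocabulary entry.
    return vocab_size * dim


def bilstm_params(input_dim, units):
    # Each LSTM has 4 gates with an input kernel, a recurrent kernel and a bias;
    # a bidirectional wrapper holds two such LSTMs.
    return 2 * 4 * (input_dim + units + 1) * units


def dense_params(n_in, n_out):
    # Weight matrix plus bias vector.
    return n_in * n_out + n_out


# Assumption: vocabulary size inferred from the summary (5002200 / 300).
vocab_size = 5002200 // 300

counts = {
    'embedding': embedding_params(vocab_size, 300),
    'bidirectional_3': bilstm_params(300, 50),
    'bidirectional_4': bilstm_params(300, 50),
    'dense_5': dense_params(20, 10),   # top_k=20 inputs -> 10 units
    'dense_6': dense_params(10, 10),
    'dense_7': dense_params(10, 5),    # fan-out layer
    'dense_8': dense_params(5, 1),     # final matching score
}
total = sum(counts.values())

for name, n in counts.items():
    print(name, n)
print('total', total)
```

The per-layer values and the total agree with the summary above, and the embedding layer accounts for almost all of the parameters.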
In [7]:
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
model.load_embedding_matrix(embedding_matrix)
In [8]:
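# Note: the callback is fed the processed test pack, so the "Validation" lines
# printed during training are scores on the test split.
pred_x, pred_y = test_pack_processed.unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))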
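For readers unfamiliar with the metrics the callback reports, here is a minimal sketch of NDCG@k and average precision for a single query, using the common textbook definitions (gain `2**rel - 1`, logarithmic discount). It is not MatchZoo's implementation, and details such as tie handling may differ; the function names and toy data are made up for illustration. The `threshold` argument mirrors the `(0.0)` shown in the metric names, on the assumption that a label above the threshold counts as relevant. The reported scores are usually the mean of such per-query values.

```python
import math


def rank_labels(y_true, y_pred):
    # Order the ground-truth labels by descending predicted score.
    pairs = sorted(zip(y_pred, y_true), key=lambda p: p[0], reverse=True)
    return [label for _, label in pairs]


def dcg_at_k(ranked_labels, k):
    return sum((2 ** rel - 1) / math.log2(i + 2)
               for i, rel in enumerate(ranked_labels[:k]))


def ndcg_at_k(y_true, y_pred, k):
    ideal = dcg_at_k(sorted(y_true, reverse=True), k)
    if ideal == 0:
        return 0.0
    return dcg_at_k(rank_labels(y_true, y_pred), k) / ideal


def average_precision(y_true, y_pred, threshold=0.0):
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(rank_labels(y_true, y_pred), start=1):
        if rel > threshold:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0


# Toy query with four candidate answers (hypothetical labels and scores).
y_true = [0, 1, 0, 1]
y_pred = [0.9, 0.8, 0.1, 0.4]
print('ndcg@3:', ndcg_at_k(y_true, y_pred, 3))
print('ap:', average_precision(y_true, y_pred))
```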
In [9]:
train_generator = mz.DataGenerator(
    train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=1,
    batch_size=20
)
print('num batches:', len(train_generator))
num batches: 102
In [10]:
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)
Epoch 1/30
102/102 [==============================] - 24s 235ms/step - loss: 1.0025
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.595728243469325 - normalized_discounted_cumulative_gain@5(0.0): 0.6453087410611237 - mean_average_precision(0.0): 0.6055350793503939
Epoch 2/30
102/102 [==============================] - 21s 208ms/step - loss: 1.0002
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6343952744869739 - normalized_discounted_cumulative_gain@5(0.0): 0.6803687365482167 - mean_average_precision(0.0): 0.6365944334270842
Epoch 3/30
102/102 [==============================] - 22s 213ms/step - loss: 0.9999
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6321468863776336 - normalized_discounted_cumulative_gain@5(0.0): 0.6792345599265245 - mean_average_precision(0.0): 0.6351879636661841
Epoch 4/30
102/102 [==============================] - 21s 208ms/step - loss: 0.9999
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6039431841160746 - normalized_discounted_cumulative_gain@5(0.0): 0.6476652529714161 - mean_average_precision(0.0): 0.6115127602489608
Epoch 5/30
102/102 [==============================] - 22s 214ms/step - loss: 1.0003
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6279274770949332 - normalized_discounted_cumulative_gain@5(0.0): 0.6782797312955924 - mean_average_precision(0.0): 0.6316855781831582
Epoch 6/30
102/102 [==============================] - 21s 211ms/step - loss: 0.9999
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 7/30
102/102 [==============================] - 22s 219ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 8/30
102/102 [==============================] - 22s 216ms/step - loss: 1.0001
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 9/30
102/102 [==============================] - 22s 215ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 10/30
102/102 [==============================] - 23s 221ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 11/30
102/102 [==============================] - 22s 215ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 12/30
102/102 [==============================] - 22s 218ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 13/30
102/102 [==============================] - 22s 218ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 14/30
102/102 [==============================] - 22s 217ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 15/30
102/102 [==============================] - 23s 227ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 16/30
102/102 [==============================] - 26s 258ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 17/30
102/102 [==============================] - 25s 241ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 18/30
102/102 [==============================] - 27s 260ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 19/30
102/102 [==============================] - 25s 249ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 20/30
102/102 [==============================] - 25s 245ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 21/30
102/102 [==============================] - 28s 272ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 22/30
102/102 [==============================] - 25s 242ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 23/30
102/102 [==============================] - 26s 255ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 24/30
102/102 [==============================] - 26s 256ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 25/30
102/102 [==============================] - 25s 241ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 26/30
102/102 [==============================] - 28s 274ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 27/30
102/102 [==============================] - 25s 245ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 28/30
102/102 [==============================] - 26s 255ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 29/30
102/102 [==============================] - 26s 255ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 30/30
102/102 [==============================] - 24s 238ms/step - loss: 1.0000
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
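The log above shows the loss settling at 1.0000 and the evaluation metrics repeating exactly from epoch 6 onward. One plausible reading, assuming the ranking task trains with a pairwise hinge loss with margin 1 (the notebook does not state the loss), is that the model has collapsed to giving every positive and negative candidate nearly the same score. The sketch below is a plain-Python illustration of that loss, not MatchZoo's code; the function name and scores are hypothetical.

```python
def rank_hinge_loss(pos_scores, neg_scores, margin=1.0):
    # Pairwise hinge: penalise a pair unless the positive document
    # outscores the negative one by at least `margin`.
    losses = [max(0.0, margin + neg - pos)
              for pos, neg in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


# Collapsed model: every candidate receives the same score.
collapsed = rank_hinge_loss([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# Healthy model: positives are clearly separated from negatives.
separated = rank_hinge_loss([2.0, 1.5, 3.0], [0.1, 0.2, 0.5])

print('collapsed:', collapsed)
print('separated:', separated)
```

If this reading applies, a constant loss equal to the margin is a signal to revisit the learning rate, optimizer, or dropout setting rather than to train for more epochs.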
Content source: faneshion/MatchZoo