In [1]:
%run init.ipynb
Using TensorFlow backend.
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [2]:
# Fix query (left) length to 10 and document (right) length to 100, and drop English stop words.
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=True)
# Fit the vocabulary and frequency filter on the training pack, then reuse them for dev/test.
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
valid_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3277.07it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:08<00:00, 2128.88it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 443061.44it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 76946.33it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 61087.65it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 357127.07it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 382711.17it/s]
Building Vocabulary from a datapack.: 100%|██████████| 234249/234249 [00:00<00:00, 1785462.63it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3638.16it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:08<00:00, 2107.97it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 80677.64it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 144150.06it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 97328.59it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 307271.83it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 373124.96it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 49205.36it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 38907.05it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 3540.82it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2163.02it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 8943.66it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 72060.99it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 100249.71it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 116323.05it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 329040.24it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 36927.55it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 38826.80it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 3748.05it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2106.07it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89301.64it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 91550.01it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 99896.44it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 183879.03it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 376464.36it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 44442.71it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 40030.45it/s]
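To sanity-check the processed pack, it can be unpacked the same way the test pack is unpacked later in this notebook; a minimal sketch (the exact columns returned by `unpack()` may vary between MatchZoo versions).
In [ ]:
# Sketch: inspect the processed training data.
train_x, train_y = train_pack_processed.unpack()
print(train_x.keys())            # text/id/length columns for both sides
print(train_x['text_left'][0])   # a fixed-length (10) sequence of term indices
print(train_y[:5])               # relevance labels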
In [3]:
preprocessor.context
Out[3]:
{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7ff0d35a0f98>,
'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7ff0d4f9b0f0>,
'vocab_size': 16546,
'embedding_input_dim': 16546,
'input_shapes': [(10,), (100,)]}
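The fitted vocabulary lives in the preprocessor context; its `term_index` is the mapping reused below to build the embedding matrix. A small sketch for peeking at it (the exact count can differ slightly from `vocab_size`, depending on how padding/OOV entries are bookkept).
In [ ]:
# Sketch: peek at the fitted vocabulary.
term_index = preprocessor.context['vocab_unit'].state['term_index']
print(len(term_index))                   # on the order of the vocab_size reported above
print(list(term_index.items())[:5])      # a few (term, index) pairs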
In [4]:
# ArcI encodes query (left) and document (right) with separate CNNs, then scores the
# concatenated representations with an MLP.
model = mz.models.ArcI()
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['embedding_output_dim'] = glove_embedding.output_dim
# One convolution block per side: 128 filters, kernel size 3, max-pooling size 4.
model.params['num_blocks'] = 1
model.params['left_filters'] = [128]
model.params['left_kernel_sizes'] = [3]
model.params['left_pool_sizes'] = [4]
model.params['right_filters'] = [128]
model.params['right_kernel_sizes'] = [3]
model.params['right_pool_sizes'] = [4]
model.params['conv_activation_func'] = 'relu'
# MLP on top of the concatenated left/right representations.
model.params['mlp_num_layers'] = 1
model.params['mlp_num_units'] = 100
model.params['mlp_num_fan_out'] = 1
model.params['mlp_activation_func'] = 'relu'
model.params['dropout_rate'] = 0.9
model.params['optimizer'] = 'adadelta'
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.backend.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
text_left (InputLayer)          (None, 10)           0
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 100)          0
__________________________________________________________________________________________________
embedding (Embedding)           multiple             4963800     text_left[0][0]
                                                                 text_right[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 10, 128)      115328      embedding[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 100, 128)     115328      embedding[1][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D)  (None, 2, 128)       0           conv1d_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D)  (None, 25, 128)      0           conv1d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 256)          0           max_pooling1d_1[0][0]
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 3200)         0           max_pooling1d_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 3456)         0           flatten_1[0][0]
                                                                 flatten_2[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 3456)         0           concatenate_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 100)          345700      dropout_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            101         dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            2           dense_2[0][0]
==================================================================================================
Total params: 5,540,259
Trainable params: 5,540,259
Non-trainable params: 0
__________________________________________________________________________________________________
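The parameter counts in the summary can be reproduced by hand; a short arithmetic sketch, assuming a 300-dimensional GloVe embedding (implied by 16546 * 300 = 4,963,800).
In [ ]:
# Sketch: reproduce the parameter counts from the summary above.
vocab_size, emb_dim, filters, kernel = 16546, 300, 128, 3
print(vocab_size * emb_dim)                   # embedding:             4,963,800
print(kernel * emb_dim * filters + filters)   # each Conv1D:           115,328
print((256 + 3200) * 100 + 100)               # dense_1 on the concat: 345,700
print(100 * 1 + 1, 1 * 1 + 1)                 # dense_2, dense_3:      101, 2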
In [5]:
# Build the embedding matrix from GloVe for every term in the fitted vocabulary,
# then copy it into the model's embedding layer.
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
model.load_embedding_matrix(embedding_matrix)
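A quick shape check on the matrix just loaded; a minimal sketch (how padding/OOV rows are counted depends on the MatchZoo version).
In [ ]:
# Sketch: one row per term index, one column per GloVe dimension.
print(embedding_matrix.shape)            # roughly (vocab_size, glove_embedding.output_dim)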
In [6]:
# Unpack the test pack once and evaluate every ranking metric on it after each epoch.
pred_x, pred_y = test_pack_processed.unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))
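`EvaluateAllMetrics` scores the whole unpacked test set once per epoch; `batch_size=len(pred_y)` covers it in a single pass. A small sketch of what the callback is given.
In [ ]:
# Sketch: the callback receives a dict of arrays (x) and the label array (y).
print(len(pred_y))
print({k: v.shape for k, v in pred_x.items()})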
In [7]:
# Generate training batches in pairwise mode: each positive document is duplicated
# twice (num_dup=2) and paired with one sampled negative (num_neg=1).
train_generator = mz.DataGenerator(
    train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=1,
    batch_size=20
)
print('num batches:', len(train_generator))
num batches: 102
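In 'pair' mode each batch holds query-document pairs, matching every (duplicated) positive with one sampled negative; a sketch for peeking at a single batch, assuming the generator supports indexing as a Keras Sequence (it is consumed by fit_generator below).
In [ ]:
# Sketch: inspect one generated batch.
batch_x, batch_y = train_generator[0]
print({k: v.shape for k, v in batch_x.items()})
print(batch_y[:4])                       # labels for the first few rows of the batch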
In [8]:
# Train for 30 epochs; the callback reports the test-set ranking metrics after every epoch.
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)
Epoch 1/30
102/102 [==============================] - 12s 113ms/step - loss: 0.9915
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834
Epoch 2/30
102/102 [==============================] - 16s 153ms/step - loss: 0.9609
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6414052424694674 - normalized_discounted_cumulative_gain@5(0.0): 0.6896110630787813 - mean_average_precision(0.0): 0.655464619746667
Epoch 3/30
102/102 [==============================] - 15s 147ms/step - loss: 0.9213
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.609449432076622 - normalized_discounted_cumulative_gain@5(0.0): 0.6535863921295467 - mean_average_precision(0.0): 0.6256401326730503
Epoch 4/30
102/102 [==============================] - 14s 139ms/step - loss: 0.8644
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5668504718763049 - normalized_discounted_cumulative_gain@5(0.0): 0.6184823173536319 - mean_average_precision(0.0): 0.5887412898235421
Epoch 5/30
102/102 [==============================] - 14s 140ms/step - loss: 0.8046
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5665334463767309 - normalized_discounted_cumulative_gain@5(0.0): 0.6176058113896864 - mean_average_precision(0.0): 0.5852152847278305
Epoch 6/30
102/102 [==============================] - 15s 148ms/step - loss: 0.8133
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5653296421157549 - normalized_discounted_cumulative_gain@5(0.0): 0.6098561773567537 - mean_average_precision(0.0): 0.5787384301794942
Epoch 7/30
102/102 [==============================] - 15s 144ms/step - loss: 0.7223
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5115731066519837 - normalized_discounted_cumulative_gain@5(0.0): 0.5696043849325363 - mean_average_precision(0.0): 0.5309471148833632
Epoch 8/30
102/102 [==============================] - 13s 129ms/step - loss: 0.7452
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5654535528839739 - normalized_discounted_cumulative_gain@5(0.0): 0.6150188550188977 - mean_average_precision(0.0): 0.5855908181807582
Epoch 9/30
102/102 [==============================] - 15s 149ms/step - loss: 0.6732
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.537504491728362 - normalized_discounted_cumulative_gain@5(0.0): 0.5888879094450309 - mean_average_precision(0.0): 0.556922967842061
Epoch 10/30
102/102 [==============================] - 15s 146ms/step - loss: 0.6431
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5294134618498313 - normalized_discounted_cumulative_gain@5(0.0): 0.5900024900110544 - mean_average_precision(0.0): 0.5542640985018126
Epoch 11/30
102/102 [==============================] - 14s 134ms/step - loss: 0.5859
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5520633189989336 - normalized_discounted_cumulative_gain@5(0.0): 0.6108468663800319 - mean_average_precision(0.0): 0.5791519476377355
Epoch 12/30
102/102 [==============================] - 15s 146ms/step - loss: 0.5602
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.534624128426663 - normalized_discounted_cumulative_gain@5(0.0): 0.5889541538361915 - mean_average_precision(0.0): 0.551507001163675
Epoch 13/30
102/102 [==============================] - 15s 144ms/step - loss: 0.5450
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5627233027710186 - normalized_discounted_cumulative_gain@5(0.0): 0.6108236823978452 - mean_average_precision(0.0): 0.584069096207155
Epoch 14/30
102/102 [==============================] - 15s 144ms/step - loss: 0.5581
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5513529195409487 - normalized_discounted_cumulative_gain@5(0.0): 0.5999350241763287 - mean_average_precision(0.0): 0.568106916524638
Epoch 15/30
102/102 [==============================] - 16s 156ms/step - loss: 0.4980
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5546799339708178 - normalized_discounted_cumulative_gain@5(0.0): 0.6152275605059057 - mean_average_precision(0.0): 0.5795762526981569
Epoch 16/30
102/102 [==============================] - 14s 141ms/step - loss: 0.5071
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5600349988760164 - normalized_discounted_cumulative_gain@5(0.0): 0.6072060773562872 - mean_average_precision(0.0): 0.5774453145256668
Epoch 17/30
102/102 [==============================] - 14s 140ms/step - loss: 0.4518
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.554905029677223 - normalized_discounted_cumulative_gain@5(0.0): 0.6021378901796073 - mean_average_precision(0.0): 0.5722005742097239
Epoch 18/30
102/102 [==============================] - 15s 145ms/step - loss: 0.4292
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5538251855689398 - normalized_discounted_cumulative_gain@5(0.0): 0.6006882253891397 - mean_average_precision(0.0): 0.5649684864479169
Epoch 19/30
102/102 [==============================] - 12s 116ms/step - loss: 0.4222
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5502126537672055 - normalized_discounted_cumulative_gain@5(0.0): 0.5933742887299561 - mean_average_precision(0.0): 0.5631647115191418
Epoch 20/30
102/102 [==============================] - 14s 142ms/step - loss: 0.3871
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.545929755381746 - normalized_discounted_cumulative_gain@5(0.0): 0.5965961908898312 - mean_average_precision(0.0): 0.5620997683843287
Epoch 21/30
102/102 [==============================] - 15s 145ms/step - loss: 0.3485
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.545785357053379 - normalized_discounted_cumulative_gain@5(0.0): 0.6002958901867365 - mean_average_precision(0.0): 0.5651984875273516
Epoch 22/30
102/102 [==============================] - 15s 152ms/step - loss: 0.3665
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.554029843099671 - normalized_discounted_cumulative_gain@5(0.0): 0.6048247555957358 - mean_average_precision(0.0): 0.5708237928362012
Epoch 23/30
102/102 [==============================] - 13s 128ms/step - loss: 0.3638
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5491694007852175 - normalized_discounted_cumulative_gain@5(0.0): 0.6068284143675057 - mean_average_precision(0.0): 0.568656261259232
Epoch 24/30
102/102 [==============================] - 13s 132ms/step - loss: 0.3220
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5501388396528246 - normalized_discounted_cumulative_gain@5(0.0): 0.6082083078502355 - mean_average_precision(0.0): 0.5698885871168076
Epoch 25/30
102/102 [==============================] - 14s 134ms/step - loss: 0.3111
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5529613793276678 - normalized_discounted_cumulative_gain@5(0.0): 0.6087770320216022 - mean_average_precision(0.0): 0.5719623784893997
Epoch 26/30
102/102 [==============================] - 15s 143ms/step - loss: 0.2875
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5535491988838098 - normalized_discounted_cumulative_gain@5(0.0): 0.607511521411928 - mean_average_precision(0.0): 0.562998351131385
Epoch 27/30
102/102 [==============================] - 14s 133ms/step - loss: 0.2849
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5469268899993501 - normalized_discounted_cumulative_gain@5(0.0): 0.6037117043525182 - mean_average_precision(0.0): 0.5606121264001305
Epoch 28/30
102/102 [==============================] - 14s 140ms/step - loss: 0.2623
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5550211003938021 - normalized_discounted_cumulative_gain@5(0.0): 0.6047800527522788 - mean_average_precision(0.0): 0.5628610719197386
Epoch 29/30
102/102 [==============================] - 14s 141ms/step - loss: 0.2663
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5457115429389982 - normalized_discounted_cumulative_gain@5(0.0): 0.6059713781830973 - mean_average_precision(0.0): 0.5609192818766827
Epoch 30/30
102/102 [==============================] - 15s 145ms/step - loss: 0.2661
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5475428253197103 - normalized_discounted_cumulative_gain@5(0.0): 0.610957175825584 - mean_average_precision(0.0): 0.5696311917780938
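The trained model can be persisted and reloaded; a minimal sketch, assuming `model.save` and `mz.load_model` behave as in other MatchZoo 2.x tutorials ('arci_ranking_model' is a hypothetical directory name).
In [ ]:
# Sketch: save the trained model to a directory and load it back.
model.save('arci_ranking_model')
loaded_model = mz.load_model('arci_ranking_model')
print(loaded_model.evaluate(pred_x, pred_y))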