In [1]:
%run init.ipynb
Using TensorFlow backend.
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
return f(*args, **kwds)
matchzoo version 2.1.0
data loading ...
data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
loading embedding ...
embedding loaded as `glove_embedding`
In [2]:
preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=True)
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
valid_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3594.37it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:11<00:00, 1586.27it/s]
Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 331119.09it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 53611.85it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 42001.20it/s]
Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 210146.80it/s]
Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 246073.81it/s]
Building Vocabulary from a datapack.: 100%|██████████| 234249/234249 [00:00<00:00, 1114012.95it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 2499.93it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:09<00:00, 2037.29it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 87829.63it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 151972.22it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 100885.06it/s]
Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 342557.20it/s]
Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 378489.78it/s]
Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 47551.82it/s]
Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 36324.09it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 3139.22it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 1648.55it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 38860.67it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 60122.79it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 6474.05it/s]
Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 122976.47it/s]
Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 273009.28it/s]
Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 29764.14it/s]
Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 25098.34it/s]
Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 2521.45it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 1527.39it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 73862.60it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 15517.48it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 54723.62it/s]
Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 135539.96it/s]
Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 281653.07it/s]
Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 27051.19it/s]
Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 24465.45it/s]
In [3]:
model = mz.models.DUET()
model.params.update(preprocessor.context)
model.params['task'] = ranking_task
model.params['embedding_output_dim'] = 300
model.params['lm_filters'] = 32
model.params['lm_hidden_sizes'] = [32]
model.params['dm_filters'] = 32
model.params['dm_kernel_size'] = 3
model.params['dm_d_mpool'] = 4
model.params['dm_hidden_sizes'] = [32]
model.params['dropout_rate'] = 0.5
optimizer = keras.optimizers.Adamax(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
model.params['optimizer'] = 'adagrad'
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.backend.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
text_left (InputLayer) (None, 10) 0
__________________________________________________________________________________________________
text_right (InputLayer) (None, 100) 0
__________________________________________________________________________________________________
embedding (Embedding) multiple 4963800 text_left[0][0]
text_right[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 10, 32) 28832 embedding[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 10, 32) 0 conv1d_2[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 100, 32) 28832 embedding[1][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D) (None, 1, 32) 0 dropout_3[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 100, 32) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 32) 0 max_pooling1d_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D) (None, 25, 32) 0 dropout_4[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 10, 100) 0 text_left[0][0]
text_right[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 32) 1056 reshape_2[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 25, 32) 1056 max_pooling1d_2[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 10, 32) 320032 lambda_1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 1, 32) 0 dense_3[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, 25, 32) 0 conv1d_4[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 10, 32) 0 conv1d_1[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda) (None, 25, 32) 0 lambda_2[0][0]
dropout_5[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 320) 0 dropout_1[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape) (None, 800) 0 lambda_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 32) 10272 reshape_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 32) 25632 reshape_3[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 32) 0 dense_1[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout) (None, 32) 0 dense_4[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1) 33 dropout_2[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 1) 33 dropout_6[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 1) 0 dense_2[0][0]
dense_5[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 1) 2 add_1[0][0]
==================================================================================================
Total params: 5,379,580
Trainable params: 5,379,580
Non-trainable params: 0
__________________________________________________________________________________________________
In [4]:
embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])
model.load_embedding_matrix(embedding_matrix)
In [5]:
pred_x, pred_y = test_pack_processed[:].unpack()
evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))
In [6]:
train_generator = mz.DataGenerator(
train_pack_processed,
mode='pair',
num_dup=2,
num_neg=1,
batch_size=20
)
print('num batches:', len(train_generator))
num batches: 102
In [7]:
history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)
Epoch 1/30
102/102 [==============================] - 10s 102ms/step - loss: 1.3443
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5434086104876706 - normalized_discounted_cumulative_gain@5(0.0): 0.5956448102614595 - mean_average_precision(0.0): 0.5563239011406736
Epoch 2/30
102/102 [==============================] - 19s 183ms/step - loss: 0.7973
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5409544601384676 - normalized_discounted_cumulative_gain@5(0.0): 0.5963200911770458 - mean_average_precision(0.0): 0.5538297188377527
Epoch 3/30
102/102 [==============================] - 18s 172ms/step - loss: 0.7230
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5349731349012871 - normalized_discounted_cumulative_gain@5(0.0): 0.5990249616772886 - mean_average_precision(0.0): 0.5591250854644376
Epoch 4/30
102/102 [==============================] - 19s 189ms/step - loss: 0.5881
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5287827044452265 - normalized_discounted_cumulative_gain@5(0.0): 0.5908037158184187 - mean_average_precision(0.0): 0.5466419807557238
Epoch 5/30
102/102 [==============================] - 20s 196ms/step - loss: 0.5804
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5240370019809759 - normalized_discounted_cumulative_gain@5(0.0): 0.6006478587685735 - mean_average_precision(0.0): 0.5450772528295846
Epoch 6/30
102/102 [==============================] - 18s 172ms/step - loss: 0.5368
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5226797909367127 - normalized_discounted_cumulative_gain@5(0.0): 0.5893317663937131 - mean_average_precision(0.0): 0.5410155727818953
Epoch 7/30
102/102 [==============================] - 17s 169ms/step - loss: 0.4807
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4951695408576308 - normalized_discounted_cumulative_gain@5(0.0): 0.5706620204391553 - mean_average_precision(0.0): 0.5178932175787853
Epoch 8/30
102/102 [==============================] - 20s 197ms/step - loss: 0.3995
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5056429827777733 - normalized_discounted_cumulative_gain@5(0.0): 0.5743399763446194 - mean_average_precision(0.0): 0.518527984938907
Epoch 9/30
102/102 [==============================] - 19s 182ms/step - loss: 0.3581
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5029149072405547 - normalized_discounted_cumulative_gain@5(0.0): 0.5728075884356819 - mean_average_precision(0.0): 0.5230789712170018
Epoch 10/30
102/102 [==============================] - 17s 168ms/step - loss: 0.3390
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5221954439731158 - normalized_discounted_cumulative_gain@5(0.0): 0.583142417896642 - mean_average_precision(0.0): 0.5377685049399303
Epoch 11/30
102/102 [==============================] - 18s 180ms/step - loss: 0.2994
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5136931363049559 - normalized_discounted_cumulative_gain@5(0.0): 0.5779734005370663 - mean_average_precision(0.0): 0.5308713805567863
Epoch 12/30
102/102 [==============================] - 20s 197ms/step - loss: 0.2930
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5144491847286299 - normalized_discounted_cumulative_gain@5(0.0): 0.5764922861029833 - mean_average_precision(0.0): 0.5238341669133789
Epoch 13/30
102/102 [==============================] - 18s 172ms/step - loss: 0.2568
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5001641695674516 - normalized_discounted_cumulative_gain@5(0.0): 0.5678059065572373 - mean_average_precision(0.0): 0.5163413086929762
Epoch 14/30
102/102 [==============================] - 17s 167ms/step - loss: 0.2333
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.49355785117051687 - normalized_discounted_cumulative_gain@5(0.0): 0.5649387413500839 - mean_average_precision(0.0): 0.5107034360845545
Epoch 15/30
102/102 [==============================] - 20s 192ms/step - loss: 0.2004
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5048462698503208 - normalized_discounted_cumulative_gain@5(0.0): 0.5777265739049287 - mean_average_precision(0.0): 0.5251572469500373
Epoch 16/30
102/102 [==============================] - 19s 187ms/step - loss: 0.2013
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.49244157839903463 - normalized_discounted_cumulative_gain@5(0.0): 0.5666747321525085 - mean_average_precision(0.0): 0.5174358597173779
Epoch 17/30
102/102 [==============================] - 17s 171ms/step - loss: 0.1627
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5008963874519605 - normalized_discounted_cumulative_gain@5(0.0): 0.5693132311623968 - mean_average_precision(0.0): 0.519828449775054
Epoch 18/30
102/102 [==============================] - 18s 180ms/step - loss: 0.1676
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4964575973232318 - normalized_discounted_cumulative_gain@5(0.0): 0.5675852335530719 - mean_average_precision(0.0): 0.5214283696480859
Epoch 19/30
102/102 [==============================] - 20s 195ms/step - loss: 0.1422
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5012929419262925 - normalized_discounted_cumulative_gain@5(0.0): 0.5641706237691956 - mean_average_precision(0.0): 0.5234524442356135
Epoch 20/30
102/102 [==============================] - 17s 167ms/step - loss: 0.1358
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5069445941553911 - normalized_discounted_cumulative_gain@5(0.0): 0.5749811028788108 - mean_average_precision(0.0): 0.5316737183006729
Epoch 21/30
102/102 [==============================] - 17s 170ms/step - loss: 0.1279
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.49648582616607223 - normalized_discounted_cumulative_gain@5(0.0): 0.5583586978837978 - mean_average_precision(0.0): 0.5178131860833248
Epoch 22/30
102/102 [==============================] - 19s 187ms/step - loss: 0.1256
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4975374414315149 - normalized_discounted_cumulative_gain@5(0.0): 0.5594682616966064 - mean_average_precision(0.0): 0.5193958200237347
Epoch 23/30
102/102 [==============================] - 15s 151ms/step - loss: 0.1248
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4982876125317161 - normalized_discounted_cumulative_gain@5(0.0): 0.5591300527258516 - mean_average_precision(0.0): 0.5188557732395174
Epoch 24/30
102/102 [==============================] - 15s 152ms/step - loss: 0.1275
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4927392704320665 - normalized_discounted_cumulative_gain@5(0.0): 0.5574523653730236 - mean_average_precision(0.0): 0.5155605033553068
Epoch 25/30
102/102 [==============================] - 15s 146ms/step - loss: 0.1103
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.49879007517410123 - normalized_discounted_cumulative_gain@5(0.0): 0.5586164397323847 - mean_average_precision(0.0): 0.5160934561034494
Epoch 26/30
102/102 [==============================] - 15s 149ms/step - loss: 0.0892
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4958789980695799 - normalized_discounted_cumulative_gain@5(0.0): 0.5614716756878623 - mean_average_precision(0.0): 0.5169777693899112
Epoch 27/30
102/102 [==============================] - 15s 148ms/step - loss: 0.0952
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4954856734956429 - normalized_discounted_cumulative_gain@5(0.0): 0.5589290693398398 - mean_average_precision(0.0): 0.5123821730435948
Epoch 28/30
102/102 [==============================] - 15s 150ms/step - loss: 0.0918
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4942614804434666 - normalized_discounted_cumulative_gain@5(0.0): 0.5593891548699772 - mean_average_precision(0.0): 0.5108165708897574
Epoch 29/30
102/102 [==============================] - 15s 150ms/step - loss: 0.0925
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.49400667687442384 - normalized_discounted_cumulative_gain@5(0.0): 0.5621486460723 - mean_average_precision(0.0): 0.5123339065745354
Epoch 30/30
102/102 [==============================] - 15s 150ms/step - loss: 0.0828
Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4966496567299287 - normalized_discounted_cumulative_gain@5(0.0): 0.5533107361861653 - mean_average_precision(0.0): 0.510489040605752
In [ ]:
Content source: faneshion/MatchZoo
Similar notebooks: