In [1]:
# code meaning predictor using advance NN techs

In [54]:
'''Trains a LSTM on the IMDB sentiment classification task.

The dataset is actually too small for LSTM to be of any advantage
compared to simpler, much faster methods such as TF-IDF+LogReg.

Notes:

- RNNs are tricky. Choice of batch size is important,
choice of loss and optimizer is critical, etc.
Some configurations won't converge.

- LSTM loss decrease patterns during training can be quite different
from what you see with CNNs/MLPs/etc.
'''
from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.preprocessing import sequence
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM, SimpleRNN, GRU
from keras.datasets import imdb

n_topic = 40
maxlen = 200  # cut texts after this number of words (among top max_features most common words)
batch_size = 10

print(len(X_train), 'train sequences')
print(len(X_test), 'test sequences')

print('Pad sequences (samples x time)')
X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)

print('Build model...')
model = Sequential()
model.add(GRU(64, dropout_W=0.25, dropout_U=0.25, input_dim=8))  # try using a GRU instead, for fun
model.add(Dense(128, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(n_topic))
model.add(Activation('softmax'))

# try using different optimizers and different optimizer configs
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

print('Train...')
print(X_train.shape)
print(Y_train.shape)
hist = model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=200,
          validation_data=(X_test, Y_test))
score, acc = model.evaluate(X_test, Y_test,
                            batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)


3183 train sequences
796 test sequences
Pad sequences (samples x time)
X_train shape: (3183, 200, 8)
X_test shape: (796, 200, 8)
Build model...
Train...
(3183, 200, 8)
(3183, 40)
Train on 3183 samples, validate on 796 samples
Epoch 1/200
3183/3183 [==============================] - 141s - loss: 2.6801 - acc: 0.0393 - val_loss: 2.6107 - val_acc: 0.0678
Epoch 2/200
3183/3183 [==============================] - 141s - loss: 2.6505 - acc: 0.0537 - val_loss: 2.6013 - val_acc: 0.0854
Epoch 3/200
3183/3183 [==============================] - 141s - loss: 2.6372 - acc: 0.0631 - val_loss: 2.5929 - val_acc: 0.0766
Epoch 4/200
3183/3183 [==============================] - 141s - loss: 2.6286 - acc: 0.0650 - val_loss: 2.5977 - val_acc: 0.0892
Epoch 5/200
3183/3183 [==============================] - 141s - loss: 2.6175 - acc: 0.0688 - val_loss: 2.5951 - val_acc: 0.0741
Epoch 6/200
3183/3183 [==============================] - 141s - loss: 2.6106 - acc: 0.0751 - val_loss: 2.5873 - val_acc: 0.0678
Epoch 7/200
3183/3183 [==============================] - 141s - loss: 2.6069 - acc: 0.0741 - val_loss: 2.5853 - val_acc: 0.0729
Epoch 8/200
3183/3183 [==============================] - 140s - loss: 2.5934 - acc: 0.0864 - val_loss: 2.5931 - val_acc: 0.0779
Epoch 9/200
3183/3183 [==============================] - 140s - loss: 2.5919 - acc: 0.0858 - val_loss: 2.5865 - val_acc: 0.0854
Epoch 10/200
3183/3183 [==============================] - 140s - loss: 2.5821 - acc: 0.0861 - val_loss: 2.5887 - val_acc: 0.0779
Epoch 11/200
3183/3183 [==============================] - 141s - loss: 2.5816 - acc: 0.0851 - val_loss: 2.5866 - val_acc: 0.0842
Epoch 12/200
3183/3183 [==============================] - 140s - loss: 2.5745 - acc: 0.0873 - val_loss: 2.5849 - val_acc: 0.0766
Epoch 13/200
3183/3183 [==============================] - 140s - loss: 2.5688 - acc: 0.0921 - val_loss: 2.5920 - val_acc: 0.0804
Epoch 14/200
3183/3183 [==============================] - 140s - loss: 2.5627 - acc: 0.0961 - val_loss: 2.5866 - val_acc: 0.0905
Epoch 15/200
3183/3183 [==============================] - 140s - loss: 2.5508 - acc: 0.0980 - val_loss: 2.5837 - val_acc: 0.0842
Epoch 16/200
3183/3183 [==============================] - 140s - loss: 2.5551 - acc: 0.0936 - val_loss: 2.5789 - val_acc: 0.0854
Epoch 17/200
3183/3183 [==============================] - 140s - loss: 2.5558 - acc: 0.0927 - val_loss: 2.5798 - val_acc: 0.0804
Epoch 18/200
3183/3183 [==============================] - 140s - loss: 2.5462 - acc: 0.0983 - val_loss: 2.5827 - val_acc: 0.0955
Epoch 19/200
3183/3183 [==============================] - 141s - loss: 2.5403 - acc: 0.1049 - val_loss: 2.5834 - val_acc: 0.0879
Epoch 20/200
3183/3183 [==============================] - 140s - loss: 2.5344 - acc: 0.1087 - val_loss: 2.5816 - val_acc: 0.0905
Epoch 21/200
3183/3183 [==============================] - 140s - loss: 2.5313 - acc: 0.1074 - val_loss: 2.5821 - val_acc: 0.0854
Epoch 22/200
3183/3183 [==============================] - 140s - loss: 2.5313 - acc: 0.1024 - val_loss: 2.5743 - val_acc: 0.0879
Epoch 23/200
3183/3183 [==============================] - 140s - loss: 2.5176 - acc: 0.1140 - val_loss: 2.5800 - val_acc: 0.0867
Epoch 24/200
3183/3183 [==============================] - 140s - loss: 2.5249 - acc: 0.1090 - val_loss: 2.5761 - val_acc: 0.0955
Epoch 25/200
3183/3183 [==============================] - 140s - loss: 2.5161 - acc: 0.1103 - val_loss: 2.5772 - val_acc: 0.0980
Epoch 26/200
3183/3183 [==============================] - 140s - loss: 2.5120 - acc: 0.1172 - val_loss: 2.5800 - val_acc: 0.0905
Epoch 27/200
3183/3183 [==============================] - 140s - loss: 2.5105 - acc: 0.1200 - val_loss: 2.5768 - val_acc: 0.0992
Epoch 28/200
3183/3183 [==============================] - 140s - loss: 2.5003 - acc: 0.1250 - val_loss: 2.5722 - val_acc: 0.0917
Epoch 29/200
3183/3183 [==============================] - 140s - loss: 2.5057 - acc: 0.1159 - val_loss: 2.5772 - val_acc: 0.0930
Epoch 30/200
3183/3183 [==============================] - 140s - loss: 2.4967 - acc: 0.1244 - val_loss: 2.5730 - val_acc: 0.0879
Epoch 31/200
3183/3183 [==============================] - 140s - loss: 2.5052 - acc: 0.1169 - val_loss: 2.5716 - val_acc: 0.0905
Epoch 32/200
3183/3183 [==============================] - 140s - loss: 2.4962 - acc: 0.1206 - val_loss: 2.5785 - val_acc: 0.0980
Epoch 33/200
3183/3183 [==============================] - 141s - loss: 2.4895 - acc: 0.1288 - val_loss: 2.5745 - val_acc: 0.1005
Epoch 34/200
3183/3183 [==============================] - 141s - loss: 2.4858 - acc: 0.1213 - val_loss: 2.5710 - val_acc: 0.1093
Epoch 35/200
3183/3183 [==============================] - 140s - loss: 2.4852 - acc: 0.1254 - val_loss: 2.5663 - val_acc: 0.1018
Epoch 36/200
3183/3183 [==============================] - 140s - loss: 2.4799 - acc: 0.1294 - val_loss: 2.5705 - val_acc: 0.1043
Epoch 37/200
3183/3183 [==============================] - 140s - loss: 2.4752 - acc: 0.1320 - val_loss: 2.5699 - val_acc: 0.1055
Epoch 38/200
3183/3183 [==============================] - 140s - loss: 2.4771 - acc: 0.1298 - val_loss: 2.5664 - val_acc: 0.1055
Epoch 39/200
3183/3183 [==============================] - 140s - loss: 2.4716 - acc: 0.1310 - val_loss: 2.5733 - val_acc: 0.1030
Epoch 40/200
3183/3183 [==============================] - 140s - loss: 2.4687 - acc: 0.1304 - val_loss: 2.5634 - val_acc: 0.1005
Epoch 41/200
3183/3183 [==============================] - 140s - loss: 2.4634 - acc: 0.1411 - val_loss: 2.5678 - val_acc: 0.1043
Epoch 42/200
3183/3183 [==============================] - 140s - loss: 2.4638 - acc: 0.1282 - val_loss: 2.5600 - val_acc: 0.1055
Epoch 43/200
3183/3183 [==============================] - 140s - loss: 2.4676 - acc: 0.1392 - val_loss: 2.5707 - val_acc: 0.0955
Epoch 44/200
3183/3183 [==============================] - 141s - loss: 2.4472 - acc: 0.1458 - val_loss: 2.5650 - val_acc: 0.1030
Epoch 45/200
3183/3183 [==============================] - 141s - loss: 2.4467 - acc: 0.1379 - val_loss: 2.5683 - val_acc: 0.1043
Epoch 46/200
3183/3183 [==============================] - 140s - loss: 2.4565 - acc: 0.1442 - val_loss: 2.5645 - val_acc: 0.1043
Epoch 47/200
3183/3183 [==============================] - 140s - loss: 2.4557 - acc: 0.1382 - val_loss: 2.5616 - val_acc: 0.1043
Epoch 48/200
3183/3183 [==============================] - 140s - loss: 2.4511 - acc: 0.1367 - val_loss: 2.5637 - val_acc: 0.1043
Epoch 49/200
3183/3183 [==============================] - 140s - loss: 2.4515 - acc: 0.1486 - val_loss: 2.5645 - val_acc: 0.1118
Epoch 50/200
3183/3183 [==============================] - 140s - loss: 2.4374 - acc: 0.1433 - val_loss: 2.5649 - val_acc: 0.1106
Epoch 51/200
3183/3183 [==============================] - 141s - loss: 2.4502 - acc: 0.1445 - val_loss: 2.5700 - val_acc: 0.1068
Epoch 52/200
3183/3183 [==============================] - 141s - loss: 2.4401 - acc: 0.1420 - val_loss: 2.5692 - val_acc: 0.1080
Epoch 53/200
3183/3183 [==============================] - 141s - loss: 2.4486 - acc: 0.1464 - val_loss: 2.5716 - val_acc: 0.0980
Epoch 54/200
3183/3183 [==============================] - 141s - loss: 2.4346 - acc: 0.1565 - val_loss: 2.5639 - val_acc: 0.1043
Epoch 55/200
3183/3183 [==============================] - 141s - loss: 2.4456 - acc: 0.1404 - val_loss: 2.5631 - val_acc: 0.1118
Epoch 56/200
3183/3183 [==============================] - 141s - loss: 2.4375 - acc: 0.1495 - val_loss: 2.5629 - val_acc: 0.1055
Epoch 57/200
3183/3183 [==============================] - 141s - loss: 2.4442 - acc: 0.1492 - val_loss: 2.5605 - val_acc: 0.1030
Epoch 58/200
3183/3183 [==============================] - 140s - loss: 2.4298 - acc: 0.1558 - val_loss: 2.5594 - val_acc: 0.1106
Epoch 59/200
3183/3183 [==============================] - 140s - loss: 2.4351 - acc: 0.1417 - val_loss: 2.5591 - val_acc: 0.1080
Epoch 60/200
3183/3183 [==============================] - 140s - loss: 2.4323 - acc: 0.1473 - val_loss: 2.5594 - val_acc: 0.1131
Epoch 61/200
3183/3183 [==============================] - 140s - loss: 2.4322 - acc: 0.1489 - val_loss: 2.5614 - val_acc: 0.1068
Epoch 62/200
3183/3183 [==============================] - 140s - loss: 2.4246 - acc: 0.1577 - val_loss: 2.5612 - val_acc: 0.1093
Epoch 63/200
3183/3183 [==============================] - 140s - loss: 2.4235 - acc: 0.1596 - val_loss: 2.5650 - val_acc: 0.1156
Epoch 64/200
3183/3183 [==============================] - 141s - loss: 2.4231 - acc: 0.1577 - val_loss: 2.5605 - val_acc: 0.1143
Epoch 65/200
3183/3183 [==============================] - 140s - loss: 2.4310 - acc: 0.1543 - val_loss: 2.5618 - val_acc: 0.1106
Epoch 66/200
3183/3183 [==============================] - 140s - loss: 2.4234 - acc: 0.1517 - val_loss: 2.5579 - val_acc: 0.1143
Epoch 67/200
3183/3183 [==============================] - 141s - loss: 2.4313 - acc: 0.1486 - val_loss: 2.5572 - val_acc: 0.1206
Epoch 68/200
3183/3183 [==============================] - 140s - loss: 2.4115 - acc: 0.1621 - val_loss: 2.5575 - val_acc: 0.1206
Epoch 69/200
3183/3183 [==============================] - 140s - loss: 2.4146 - acc: 0.1653 - val_loss: 2.5571 - val_acc: 0.1118
Epoch 70/200
3183/3183 [==============================] - 140s - loss: 2.4113 - acc: 0.1649 - val_loss: 2.5557 - val_acc: 0.1156
Epoch 71/200
3183/3183 [==============================] - 141s - loss: 2.4110 - acc: 0.1659 - val_loss: 2.5671 - val_acc: 0.1106
Epoch 72/200
3183/3183 [==============================] - 141s - loss: 2.4058 - acc: 0.1621 - val_loss: 2.5662 - val_acc: 0.1118
Epoch 73/200
3183/3183 [==============================] - 141s - loss: 2.3949 - acc: 0.1728 - val_loss: 2.5600 - val_acc: 0.1055
Epoch 74/200
3183/3183 [==============================] - 141s - loss: 2.4068 - acc: 0.1643 - val_loss: 2.5609 - val_acc: 0.1030
Epoch 75/200
3183/3183 [==============================] - 140s - loss: 2.4111 - acc: 0.1599 - val_loss: 2.5601 - val_acc: 0.1093
Epoch 76/200
3183/3183 [==============================] - 141s - loss: 2.4008 - acc: 0.1640 - val_loss: 2.5650 - val_acc: 0.1080
Epoch 77/200
3183/3183 [==============================] - 140s - loss: 2.3973 - acc: 0.1703 - val_loss: 2.5616 - val_acc: 0.1143
Epoch 78/200
3183/3183 [==============================] - 140s - loss: 2.3984 - acc: 0.1675 - val_loss: 2.5546 - val_acc: 0.1193
Epoch 79/200
3183/3183 [==============================] - 140s - loss: 2.3999 - acc: 0.1700 - val_loss: 2.5539 - val_acc: 0.1156
Epoch 80/200
3183/3183 [==============================] - 141s - loss: 2.4028 - acc: 0.1659 - val_loss: 2.5601 - val_acc: 0.1181
Epoch 81/200
3183/3183 [==============================] - 140s - loss: 2.3951 - acc: 0.1715 - val_loss: 2.5552 - val_acc: 0.1231
Epoch 82/200
3183/3183 [==============================] - 140s - loss: 2.3940 - acc: 0.1725 - val_loss: 2.5582 - val_acc: 0.1219
Epoch 83/200
3183/3183 [==============================] - 140s - loss: 2.4001 - acc: 0.1665 - val_loss: 2.5545 - val_acc: 0.1193
Epoch 84/200
3183/3183 [==============================] - 141s - loss: 2.3924 - acc: 0.1668 - val_loss: 2.5549 - val_acc: 0.1156
Epoch 85/200
3183/3183 [==============================] - 140s - loss: 2.3945 - acc: 0.1697 - val_loss: 2.5555 - val_acc: 0.1206
Epoch 86/200
3183/3183 [==============================] - 141s - loss: 2.3906 - acc: 0.1668 - val_loss: 2.5538 - val_acc: 0.1231
Epoch 87/200
3183/3183 [==============================] - 140s - loss: 2.3926 - acc: 0.1687 - val_loss: 2.5501 - val_acc: 0.1231
Epoch 88/200
3183/3183 [==============================] - 140s - loss: 2.3990 - acc: 0.1634 - val_loss: 2.5490 - val_acc: 0.1131
Epoch 89/200
3183/3183 [==============================] - 140s - loss: 2.3871 - acc: 0.1756 - val_loss: 2.5502 - val_acc: 0.1244
Epoch 90/200
3183/3183 [==============================] - 141s - loss: 2.4054 - acc: 0.1640 - val_loss: 2.5487 - val_acc: 0.1244
Epoch 91/200
3183/3183 [==============================] - 141s - loss: 2.3918 - acc: 0.1637 - val_loss: 2.5590 - val_acc: 0.1168
Epoch 92/200
3183/3183 [==============================] - 141s - loss: 2.3890 - acc: 0.1684 - val_loss: 2.5568 - val_acc: 0.1131
Epoch 93/200
3183/3183 [==============================] - 141s - loss: 2.3776 - acc: 0.1728 - val_loss: 2.5530 - val_acc: 0.1256
Epoch 94/200
3183/3183 [==============================] - 141s - loss: 2.3848 - acc: 0.1728 - val_loss: 2.5605 - val_acc: 0.1244
Epoch 95/200
3183/3183 [==============================] - 141s - loss: 2.3845 - acc: 0.1653 - val_loss: 2.5673 - val_acc: 0.1168
Epoch 96/200
3183/3183 [==============================] - 141s - loss: 2.3747 - acc: 0.1766 - val_loss: 2.5674 - val_acc: 0.1156
Epoch 97/200
3183/3183 [==============================] - 141s - loss: 2.3907 - acc: 0.1609 - val_loss: 2.5632 - val_acc: 0.1168
Epoch 98/200
3183/3183 [==============================] - 141s - loss: 2.3844 - acc: 0.1740 - val_loss: 2.5592 - val_acc: 0.1181
Epoch 99/200
3183/3183 [==============================] - 140s - loss: 2.3887 - acc: 0.1631 - val_loss: 2.5578 - val_acc: 0.1080
Epoch 100/200
3183/3183 [==============================] - 141s - loss: 2.3828 - acc: 0.1659 - val_loss: 2.5544 - val_acc: 0.1143
Epoch 101/200
3183/3183 [==============================] - 141s - loss: 2.3811 - acc: 0.1693 - val_loss: 2.5548 - val_acc: 0.1143
Epoch 102/200
3183/3183 [==============================] - 140s - loss: 2.3785 - acc: 0.1722 - val_loss: 2.5561 - val_acc: 0.1143
Epoch 103/200
3183/3183 [==============================] - 140s - loss: 2.3950 - acc: 0.1637 - val_loss: 2.5639 - val_acc: 0.1168
Epoch 104/200
3183/3183 [==============================] - 140s - loss: 2.3817 - acc: 0.1725 - val_loss: 2.5555 - val_acc: 0.1106
Epoch 105/200
3183/3183 [==============================] - 141s - loss: 2.3756 - acc: 0.1800 - val_loss: 2.5558 - val_acc: 0.1143
Epoch 106/200
3183/3183 [==============================] - 141s - loss: 2.3700 - acc: 0.1800 - val_loss: 2.5538 - val_acc: 0.1181
Epoch 107/200
3183/3183 [==============================] - 140s - loss: 2.3825 - acc: 0.1706 - val_loss: 2.5538 - val_acc: 0.1055
Epoch 108/200
3183/3183 [==============================] - 141s - loss: 2.3802 - acc: 0.1734 - val_loss: 2.5646 - val_acc: 0.1206
Epoch 109/200
3183/3183 [==============================] - 141s - loss: 2.3838 - acc: 0.1740 - val_loss: 2.5575 - val_acc: 0.1193
Epoch 110/200
3183/3183 [==============================] - 140s - loss: 2.3916 - acc: 0.1747 - val_loss: 2.5597 - val_acc: 0.1156
Epoch 111/200
3183/3183 [==============================] - 140s - loss: 2.3778 - acc: 0.1722 - val_loss: 2.5629 - val_acc: 0.1131
Epoch 112/200
3183/3183 [==============================] - 141s - loss: 2.3635 - acc: 0.1772 - val_loss: 2.5636 - val_acc: 0.1156
Epoch 113/200
3183/3183 [==============================] - 140s - loss: 2.3724 - acc: 0.1769 - val_loss: 2.5650 - val_acc: 0.1156
Epoch 114/200
3183/3183 [==============================] - 140s - loss: 2.3732 - acc: 0.1744 - val_loss: 2.5667 - val_acc: 0.1181
Epoch 115/200
3183/3183 [==============================] - 140s - loss: 2.3813 - acc: 0.1822 - val_loss: 2.5596 - val_acc: 0.1156
Epoch 116/200
3183/3183 [==============================] - 141s - loss: 2.3604 - acc: 0.1753 - val_loss: 2.5612 - val_acc: 0.1206
Epoch 117/200
3183/3183 [==============================] - 140s - loss: 2.3712 - acc: 0.1737 - val_loss: 2.5583 - val_acc: 0.1181
Epoch 118/200
3183/3183 [==============================] - 140s - loss: 2.3705 - acc: 0.1687 - val_loss: 2.5584 - val_acc: 0.1219
Epoch 119/200
3183/3183 [==============================] - 140s - loss: 2.3645 - acc: 0.1847 - val_loss: 2.5659 - val_acc: 0.1106
Epoch 120/200
3183/3183 [==============================] - 141s - loss: 2.3662 - acc: 0.1747 - val_loss: 2.5596 - val_acc: 0.1080
Epoch 121/200
3183/3183 [==============================] - 141s - loss: 2.3660 - acc: 0.1803 - val_loss: 2.5572 - val_acc: 0.1093
Epoch 122/200
3183/3183 [==============================] - 141s - loss: 2.3620 - acc: 0.1844 - val_loss: 2.5542 - val_acc: 0.1156
Epoch 123/200
3183/3183 [==============================] - 140s - loss: 2.3562 - acc: 0.1791 - val_loss: 2.5531 - val_acc: 0.1143
Epoch 124/200
3183/3183 [==============================] - 141s - loss: 2.3649 - acc: 0.1756 - val_loss: 2.5596 - val_acc: 0.1106
Epoch 125/200
3183/3183 [==============================] - 141s - loss: 2.3559 - acc: 0.1850 - val_loss: 2.5536 - val_acc: 0.1181
Epoch 126/200
3183/3183 [==============================] - 140s - loss: 2.3720 - acc: 0.1775 - val_loss: 2.5570 - val_acc: 0.1156
Epoch 127/200
3183/3183 [==============================] - 141s - loss: 2.3646 - acc: 0.1719 - val_loss: 2.5533 - val_acc: 0.1206
Epoch 128/200
3183/3183 [==============================] - 141s - loss: 2.3636 - acc: 0.1759 - val_loss: 2.5484 - val_acc: 0.1281
Epoch 129/200
3183/3183 [==============================] - 140s - loss: 2.3572 - acc: 0.1832 - val_loss: 2.5629 - val_acc: 0.1181
Epoch 130/200
3183/3183 [==============================] - 140s - loss: 2.3555 - acc: 0.1863 - val_loss: 2.5625 - val_acc: 0.1118
Epoch 131/200
3183/3183 [==============================] - 141s - loss: 2.3562 - acc: 0.1740 - val_loss: 2.5598 - val_acc: 0.1118
Epoch 132/200
3183/3183 [==============================] - 140s - loss: 2.3662 - acc: 0.1781 - val_loss: 2.5529 - val_acc: 0.1156
Epoch 133/200
3183/3183 [==============================] - 141s - loss: 2.3710 - acc: 0.1847 - val_loss: 2.5480 - val_acc: 0.1206
Epoch 134/200
3183/3183 [==============================] - 140s - loss: 2.3521 - acc: 0.1885 - val_loss: 2.5553 - val_acc: 0.1193
Epoch 135/200
3183/3183 [==============================] - 140s - loss: 2.3494 - acc: 0.1791 - val_loss: 2.5569 - val_acc: 0.1106
Epoch 136/200
3183/3183 [==============================] - 140s - loss: 2.3634 - acc: 0.1756 - val_loss: 2.5512 - val_acc: 0.1231
Epoch 137/200
3183/3183 [==============================] - 140s - loss: 2.3502 - acc: 0.1872 - val_loss: 2.5478 - val_acc: 0.1256
Epoch 138/200
3183/3183 [==============================] - 140s - loss: 2.3469 - acc: 0.1847 - val_loss: 2.5544 - val_acc: 0.1219
Epoch 139/200
3183/3183 [==============================] - 140s - loss: 2.3540 - acc: 0.1872 - val_loss: 2.5492 - val_acc: 0.1206
Epoch 140/200
3183/3183 [==============================] - 140s - loss: 2.3520 - acc: 0.1885 - val_loss: 2.5573 - val_acc: 0.1294
Epoch 141/200
3183/3183 [==============================] - 140s - loss: 2.3539 - acc: 0.1832 - val_loss: 2.5551 - val_acc: 0.1231
Epoch 142/200
3183/3183 [==============================] - 141s - loss: 2.3581 - acc: 0.1835 - val_loss: 2.5491 - val_acc: 0.1319
Epoch 143/200
3183/3183 [==============================] - 141s - loss: 2.3670 - acc: 0.1794 - val_loss: 2.5498 - val_acc: 0.1256
Epoch 144/200
3183/3183 [==============================] - 141s - loss: 2.3573 - acc: 0.1872 - val_loss: 2.5559 - val_acc: 0.1256
Epoch 145/200
3183/3183 [==============================] - 141s - loss: 2.3617 - acc: 0.1788 - val_loss: 2.5533 - val_acc: 0.1193
Epoch 146/200
3183/3183 [==============================] - 140s - loss: 2.3499 - acc: 0.1885 - val_loss: 2.5587 - val_acc: 0.1281
Epoch 147/200
3183/3183 [==============================] - 141s - loss: 2.3547 - acc: 0.1835 - val_loss: 2.5596 - val_acc: 0.1269
Epoch 148/200
3183/3183 [==============================] - 141s - loss: 2.3549 - acc: 0.1844 - val_loss: 2.5525 - val_acc: 0.1382
Epoch 149/200
3183/3183 [==============================] - 140s - loss: 2.3679 - acc: 0.1794 - val_loss: 2.5536 - val_acc: 0.1256
Epoch 150/200
3183/3183 [==============================] - 140s - loss: 2.3559 - acc: 0.1872 - val_loss: 2.5558 - val_acc: 0.1206
Epoch 151/200
3183/3183 [==============================] - 140s - loss: 2.3540 - acc: 0.1781 - val_loss: 2.5558 - val_acc: 0.1231
Epoch 152/200
3183/3183 [==============================] - 140s - loss: 2.3518 - acc: 0.1781 - val_loss: 2.5527 - val_acc: 0.1294
Epoch 153/200
3183/3183 [==============================] - 140s - loss: 2.3500 - acc: 0.1762 - val_loss: 2.5478 - val_acc: 0.1319
Epoch 154/200
3183/3183 [==============================] - 141s - loss: 2.3427 - acc: 0.1885 - val_loss: 2.5506 - val_acc: 0.1231
Epoch 155/200
3183/3183 [==============================] - 141s - loss: 2.3469 - acc: 0.1803 - val_loss: 2.5574 - val_acc: 0.1206
Epoch 156/200
3183/3183 [==============================] - 141s - loss: 2.3561 - acc: 0.1791 - val_loss: 2.5546 - val_acc: 0.1319
Epoch 157/200
3183/3183 [==============================] - 141s - loss: 2.3580 - acc: 0.1872 - val_loss: 2.5478 - val_acc: 0.1344
Epoch 158/200
3183/3183 [==============================] - 140s - loss: 2.3487 - acc: 0.1810 - val_loss: 2.5526 - val_acc: 0.1231
Epoch 159/200
3183/3183 [==============================] - 140s - loss: 2.3412 - acc: 0.1822 - val_loss: 2.5602 - val_acc: 0.1181
Epoch 160/200
3183/3183 [==============================] - 140s - loss: 2.3507 - acc: 0.1759 - val_loss: 2.5533 - val_acc: 0.1181
Epoch 161/200
3183/3183 [==============================] - 140s - loss: 2.3422 - acc: 0.1967 - val_loss: 2.5516 - val_acc: 0.1193
Epoch 162/200
3183/3183 [==============================] - 141s - loss: 2.3404 - acc: 0.1854 - val_loss: 2.5613 - val_acc: 0.1219
Epoch 163/200
3183/3183 [==============================] - 141s - loss: 2.3288 - acc: 0.1979 - val_loss: 2.5603 - val_acc: 0.1118
Epoch 164/200
3183/3183 [==============================] - 140s - loss: 2.3450 - acc: 0.1885 - val_loss: 2.5613 - val_acc: 0.1219
Epoch 165/200
3183/3183 [==============================] - 140s - loss: 2.3526 - acc: 0.1841 - val_loss: 2.5465 - val_acc: 0.1281
Epoch 166/200
3183/3183 [==============================] - 140s - loss: 2.3401 - acc: 0.1885 - val_loss: 2.5528 - val_acc: 0.1244
Epoch 167/200
3183/3183 [==============================] - 140s - loss: 2.3534 - acc: 0.1803 - val_loss: 2.5443 - val_acc: 0.1219
Epoch 168/200
3183/3183 [==============================] - 141s - loss: 2.3350 - acc: 0.1891 - val_loss: 2.5610 - val_acc: 0.1055
Epoch 169/200
3183/3183 [==============================] - 140s - loss: 2.3405 - acc: 0.1788 - val_loss: 2.5551 - val_acc: 0.1080
Epoch 170/200
3183/3183 [==============================] - 141s - loss: 2.3404 - acc: 0.1781 - val_loss: 2.5598 - val_acc: 0.1143
Epoch 171/200
3183/3183 [==============================] - 141s - loss: 2.3353 - acc: 0.1844 - val_loss: 2.5474 - val_acc: 0.1244
Epoch 172/200
3183/3183 [==============================] - 141s - loss: 2.3437 - acc: 0.1894 - val_loss: 2.5507 - val_acc: 0.1206
Epoch 173/200
3183/3183 [==============================] - 141s - loss: 2.3390 - acc: 0.1901 - val_loss: 2.5563 - val_acc: 0.1206
Epoch 174/200
3183/3183 [==============================] - 141s - loss: 2.3384 - acc: 0.1876 - val_loss: 2.5486 - val_acc: 0.1206
Epoch 175/200
3183/3183 [==============================] - 141s - loss: 2.3211 - acc: 0.1894 - val_loss: 2.5444 - val_acc: 0.1193
Epoch 176/200
3183/3183 [==============================] - 141s - loss: 2.3403 - acc: 0.1806 - val_loss: 2.5498 - val_acc: 0.1206
Epoch 177/200
3183/3183 [==============================] - 140s - loss: 2.3460 - acc: 0.1882 - val_loss: 2.5440 - val_acc: 0.1244
Epoch 178/200
3183/3183 [==============================] - 140s - loss: 2.3433 - acc: 0.1860 - val_loss: 2.5478 - val_acc: 0.1256
Epoch 179/200
3183/3183 [==============================] - 140s - loss: 2.3332 - acc: 0.1869 - val_loss: 2.5514 - val_acc: 0.1131
Epoch 180/200
3183/3183 [==============================] - 140s - loss: 2.3420 - acc: 0.1891 - val_loss: 2.5482 - val_acc: 0.1256
Epoch 181/200
3183/3183 [==============================] - 140s - loss: 2.3441 - acc: 0.1907 - val_loss: 2.5438 - val_acc: 0.1319
Epoch 182/200
3183/3183 [==============================] - 141s - loss: 2.3424 - acc: 0.1929 - val_loss: 2.5465 - val_acc: 0.1332
Epoch 183/200
3183/3183 [==============================] - 141s - loss: 2.3284 - acc: 0.1998 - val_loss: 2.5500 - val_acc: 0.1193
Epoch 184/200
3183/3183 [==============================] - 141s - loss: 2.3311 - acc: 0.1835 - val_loss: 2.5420 - val_acc: 0.1294
Epoch 185/200
3183/3183 [==============================] - 141s - loss: 2.3476 - acc: 0.1794 - val_loss: 2.5490 - val_acc: 0.1294
Epoch 186/200
3183/3183 [==============================] - 141s - loss: 2.3333 - acc: 0.1888 - val_loss: 2.5347 - val_acc: 0.1319
Epoch 187/200
3183/3183 [==============================] - 141s - loss: 2.3247 - acc: 0.1935 - val_loss: 2.5447 - val_acc: 0.1256
Epoch 188/200
3183/3183 [==============================] - 142s - loss: 2.3281 - acc: 0.1929 - val_loss: 2.5445 - val_acc: 0.1256
Epoch 189/200
3183/3183 [==============================] - 141s - loss: 2.3309 - acc: 0.1869 - val_loss: 2.5593 - val_acc: 0.1256
Epoch 190/200
3183/3183 [==============================] - 141s - loss: 2.3359 - acc: 0.1872 - val_loss: 2.5464 - val_acc: 0.1332
Epoch 191/200
3183/3183 [==============================] - 141s - loss: 2.3320 - acc: 0.1942 - val_loss: 2.5531 - val_acc: 0.1281
Epoch 192/200
3183/3183 [==============================] - 141s - loss: 2.3327 - acc: 0.1844 - val_loss: 2.5526 - val_acc: 0.1256
Epoch 193/200
3183/3183 [==============================] - 141s - loss: 2.3414 - acc: 0.1869 - val_loss: 2.5504 - val_acc: 0.1244
Epoch 194/200
3183/3183 [==============================] - 141s - loss: 2.3485 - acc: 0.1913 - val_loss: 2.5466 - val_acc: 0.1332
Epoch 195/200
3183/3183 [==============================] - 141s - loss: 2.3357 - acc: 0.1898 - val_loss: 2.5492 - val_acc: 0.1269
Epoch 196/200
3183/3183 [==============================] - 141s - loss: 2.3342 - acc: 0.1891 - val_loss: 2.5485 - val_acc: 0.1219
Epoch 197/200
3183/3183 [==============================] - 142s - loss: 2.3295 - acc: 0.1920 - val_loss: 2.5554 - val_acc: 0.1332
Epoch 198/200
3183/3183 [==============================] - 142s - loss: 2.3354 - acc: 0.1957 - val_loss: 2.5482 - val_acc: 0.1244
Epoch 199/200
3183/3183 [==============================] - 141s - loss: 2.3238 - acc: 0.1916 - val_loss: 2.5482 - val_acc: 0.1394
Epoch 200/200
3183/3183 [==============================] - 141s - loss: 2.3292 - acc: 0.1957 - val_loss: 2.5478 - val_acc: 0.1294
796/796 [==============================] - 13s    
Test score: 2.54778578803
Test accuracy: 0.129396987414

In [55]:
# save model
def save_model(model_filename) :
    print(model.to_json(), file=open(model_filename+'.json', 'w'))
    model.save_weights(model_filename+'.weight.h5')
save_model('../model/ssn,i200,rnn64,f128tanh,f40ce,200ep')

In [62]:
print(*hist.history['val_loss'], sep='\n', file=open('temp.txt', 'w'))

In [52]:
import numpy as np
ratio_train = 0.8
rand_seed = 1337

print('Loading data...')
def sparse_to_row(ivlist) :
    a = np.zeros(n_topics)
    for i, v in ivlist :
        a[i] = v
    return a
Y = np.array([sparse_to_row(lda[c]) for c in corpus])
X = np.array([np.array([code_model[ins]
        for i, ins in enumerate(ins_seq) if i<200
    ]) for ins_seq in l_ins
])

n = len(X)
n_train = int(ratio_train * n)
np.random.seed(rand_seed)
ind = np.random.permutation(n)
ind_train = ind[:n_train]
ind_test = ind[n_train:]
(X_train, Y_train) = X[ind_train], Y[ind_train]
(X_test, Y_test) = X[ind_test], Y[ind_test]

print('X_train shape:', X_train.shape)
print('Y_train shape:', Y_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')


Loading data...
X_train shape: (3183,)
Y_train shape: (3183, 40)
3183 train samples
796 test samples

In [14]:
import data
import topicmodeling as tp
import gensim as g

l_interested = list(set(data.get_entity_list(data.libinfo.interested_libs)))
print(len(l_interested))
l_code = list(map(data.get_code, l_interested))
l_doc = list(map(data.get_doc, l_interested))
l_doc = list(map(lambda doc : data.pdoc.extract(doc, stage=data.pdoc.ADVANCED), l_doc))


C:\Anaconda3\envs\cnn\lib\site-packages\sklearn\lda.py:4: DeprecationWarning: lda.LDA has been moved to discriminant_analysis.LinearDiscriminantAnalysis in 0.17 and will be removed in 0.19
  "in 0.17 and will be removed in 0.19", DeprecationWarning)
C:\Anaconda3\envs\cnn\lib\site-packages\sklearn\metrics\metrics.py:4: DeprecationWarning: sklearn.metrics.metrics is deprecated and will be removed in 0.18. Please import from sklearn.metrics
  DeprecationWarning)
C:\Anaconda3\envs\cnn\lib\site-packages\sklearn\qda.py:4: DeprecationWarning: qda.QDA has been moved to discriminant_analysis.QuadraticDiscriminantAnalysis in 0.17 and will be removed in 0.19.
  "in 0.17 and will be removed in 0.19.", DeprecationWarning)
3979

In [15]:
# In[34]:

documents = l_doc
texts = tp.simple_process(documents=documents, stoplist=tp.read_stoplist("../SmartStoplist.txt"))
texts = tp.remove_infrequent(texts, n_times=1)


# In[35]:

id2word = g.corpora.Dictionary(texts)
#id2word.save('/tmp/deerwester.dict') # store the id2word, for future reference
print(*list(id2word)[:10])


# In[36]:

corpus = [id2word.doc2bow(text) for text in texts]
#g.corpora.corpusCorpus.serialize('/tmp/deerwester.mm', corpus) # store to disk, for later use
print(*list(corpus)[:10])
tfidf = g.models.tfidfmodel.TfidfModel(corpus)
corpus = [tfidf[bag] for bag in corpus]
print(*list(corpus)[:10])


# In[ ]:

# load id->word mapping (the id2word), one of the results of step 2 above
#id2word = g.g.corpora.id2word.load_from_text('wiki_en_wordids.txt')
# load corpus iterator
#corpus = g.corpora.MmCorpus('/tmp/deerwester.mm')
#corpus = g.g.corpora.MmCorpus(bz2.BZ2File('wiki_en_tfidf.mm.bz2')) # use this if you compressed the TFIDF output (recommended)


# In[81]:

n_topics = 40
## extract 100 LDA topics, using 1 pass and updating once every 1 chunk (10,000 documents)
lda = g.models.ldamodel.LdaModel(corpus=corpus, id2word=id2word, num_topics=n_topics, 
                                 update_every=1, chunksize=10000, passes=5)
## print the most contributing words for n_topic topics
l = list(lda.print_topics(n_topics))
for i, s in l :
    print(i, *s.split(' + '))


685 27 471 682 467 476 1355 1201 368 1278
[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)] [(5, 1), (6, 1), (7, 1), (8, 1), (9, 1)] [(10, 2), (11, 1), (12, 1), (13, 1), (14, 1)] [(13, 1), (15, 1), (16, 1), (17, 1), (18, 1)] [(16, 1), (19, 1), (20, 1), (21, 1)] [(16, 1), (22, 1), (23, 1), (24, 1)] [(25, 1), (26, 1)] [(15, 1), (16, 1), (17, 1), (18, 1), (27, 1)] [(1, 1), (4, 1), (16, 1), (20, 1), (28, 1), (29, 1)] [(16, 1), (22, 1), (23, 1), (24, 1)]
[(0, 0.49399426709707445), (1, 0.43248535675580047), (2, 0.5739877358718446), (3, 0.44819876007159315), (4, 0.19642309132592412)] [(5, 0.46641051263332084), (6, 0.38002353185707194), (7, 0.44045878078803113), (8, 0.35597827854819025), (9, 0.563310639495286)] [(10, 0.49238052046400904), (11, 0.5162334263963622), (12, 0.5162334263963622), (13, 0.33984698427280347), (14, 0.3302598208841272)] [(13, 0.4631857093147272), (15, 0.5764470680488752), (16, 0.2193312382637294), (17, 0.41208740142097555), (18, 0.4850211932856909)] [(16, 0.3068937687430323), (19, 0.5298718859547009), (20, 0.521862600620502), (21, 0.5938951298462658)] [(16, 0.2444281370488579), (22, 0.5772121652177122), (23, 0.5573298383472616), (24, 0.5444854942340711)] [(25, 0.8875667100514567), (26, 0.4606792107404384)] [(15, 0.5094137925202258), (16, 0.19382587594774683), (17, 0.3641670114104303), (18, 0.4286195545423442), (27, 0.6217690753548032)] [(1, 0.45362216124891513), (4, 0.20602285329343398), (16, 0.22906910811584152), (20, 0.389524365296733), (28, 0.6345856622411846), (29, 0.38063193319767935)] [(16, 0.2444281370488579), (22, 0.5772121652177122), (23, 0.5573298383472616), (24, 0.5444854942340711)]
0 0.041*laguerr 0.028*repeat 0.026*center 0.025*norm 0.021*transfer 0.017*point 0.014*substr 0.014*frobeniu 0.014*valueerror 0.013*seri
1 0.046*kernel 0.040*bound 0.027*place 0.026*variat 0.023*lower 0.020*slice 0.019*callabl 0.018*comput 0.017*ensembl 0.015*matrix
2 0.039*matric 0.028*argument 0.027*spars 0.018*transform 0.016*friedman 0.016*space 0.015*directli 0.014*distanc 0.014*gener 0.013*function
3 0.068*log 0.061*probabl 0.050*densiti 0.042*covari 0.034*likelihood 0.034*estim 0.028*function 0.026*determin 0.024*model 0.018*comput
4 0.059*distribut 0.052*wishart 0.051*invers 0.026*creat 0.020*frozen 0.017*equat 0.016*matrix 0.015*predict_log_proba 0.015*period 0.015*sampl
5 0.029*instanc 0.029*linkag 0.022*multidimension 0.022*chebyshev 0.020*condens 0.020*quadrat 0.018*permut 0.018*convert 0.018*gaussian 0.015*seed
6 0.032*binari 0.028*ndarrai 0.027*cosin 0.021*type 0.018*cast 0.017*tupl 0.017*width 0.017*oper 0.016*arrai 0.015*length
7 0.135*wise 0.100*element 0.040*string 0.029*represent 0.028*reduct 0.024*dimension 0.018*unchang 0.015*appli 0.013*peak 0.013*store
8 0.037*set 0.035*system 0.030*attribut 0.020*represent 0.019*auxiliari 0.019*base_estimator_ 0.019*statespac 0.019*zerospolesgain 0.016*paramet 0.015*packag
9 0.058*axi 0.057*integr 0.048*maximum 0.046*valu 0.043*mask 0.032*indic 0.029*minimum 0.028*arrai 0.025*intern 0.021*definit
10 0.056*degre 0.056*vandermond 0.046*pseudo 0.028*matrix 0.023*equat 0.020*solv 0.020*exp 0.018*learn 0.017*curv 0.015*kei
11 0.031*underli 0.028*numpi 0.027*companion 0.024*limit 0.024*pickl 0.023*length 0.022*arrai 0.021*estim 0.020*pair 0.020*reduc
12 0.076*evalu 0.057*point 0.045*root 0.043*seri 0.033*legendr 0.033*find 0.030*spline 0.029*interpol 0.029*polynomi 0.024*axi
13 0.061*train 0.060*data 0.052*fit 0.037*model 0.024*convert 0.023*dimens 0.023*input 0.022*squar 0.021*mask 0.020*arrai
14 0.045*variabl 0.025*error 0.025*trim 0.024*minimum 0.023*function 0.021*document 0.018*stream 0.018*nonzero 0.017*comput 0.016*scalar
15 0.040*classif 0.027*vector 0.025*gener 0.024*project 0.022*line 0.020*random 0.019*matrix 0.019*perform 0.017*test 0.015*graph
16 0.027*run 0.026*iter 0.026*imag 0.023*dataset 0.022*load 0.021*file 0.017*absolut 0.017*step 0.017*estim 0.016*em
17 0.039*final 0.038*transform 0.034*estim 0.024*valid 0.023*bin 0.023*data 0.022*appli 0.021*reconstruct 0.019*transpos 0.018*predict_proba
18 0.039*dimension 0.036*fourier 0.031*hermit 0.029*discret 0.028*divid 0.028*seri 0.026*add 0.026*sort 0.025*differenti 0.024*construct
19 0.058*true 0.040*charact 0.034*element 0.030*score 0.029*test 0.026*string 0.026*check 0.023*fals 0.019*scalar 0.018*differ
20 0.046*fill 0.033*make 0.025*logarithm 0.024*singl 0.023*scale 0.018*fill_valu 0.017*default 0.016*item 0.016*current 0.015*mask
21 0.049*correl 0.036*privat 0.031*filter 0.025*function 0.025*calcul 0.024*cross 0.022*job 0.019*lowpass 0.018*coeffici 0.018*arrai
22 0.043*statist 0.027*stage 0.021*pass 0.021*decision_funct 0.021*descript 0.017*data 0.017*compat 0.012*arrai 0.012*swap 0.011*estim
23 0.047*function 0.047*multipli 0.037*decis 0.030*comput 0.026*nt 0.024*kelvin 0.021*flatten 0.021*helper 0.020*classifi 0.020*arrai
24 0.052*copi 0.050*row 0.037*format 0.036*matrix 0.033*column 0.028*spars 0.026*compress 0.025*pole 0.015*gain 0.013*coeffici
25 0.050*shape 0.047*arrai 0.034*mask 0.033*featur 0.025*record 0.025*equal 0.018*label 0.018*count 0.017*transform 0.016*subclass
26 0.072*class 0.072*predict 0.037*probabl 0.032*averag 0.019*axi 0.019*file 0.017*batch 0.017*sourc 0.017*biclust 0.017*label
27 0.041*approxim 0.032*quadratur 0.030*gauss 0.025*encod 0.019*jacobian 0.018*label 0.017*map 0.015*size 0.013*pade 0.012*equival
28 0.081*read 0.062*header 0.032*bit 0.026*integ 0.020*exponenti 0.019*write 0.018*matrix 0.016*posit 0.016*sign 0.013*byte
29 0.042*updat 0.035*varianc 0.030*region 0.028*distribut 0.022*version 0.020*dictionari 0.018*dirichlet 0.017*gauss 0.017*quadratur 0.016*binar
30 0.048*neg 0.027*solver 0.026*partial 0.025*comput 0.025*rang 0.021*ham 0.019*output 0.018*sampl 0.017*expans 0.013*geometr
31 0.077*call 0.038*paramet 0.038*found 0.037*estim 0.024*execut 0.022*behavior 0.020*part 0.019*metric 0.015*ensur 0.014*predict
32 0.042*bessel 0.042*structur 0.037*deriv 0.025*field 0.024*concaten 0.022*main 0.021*function 0.020*arrai 0.018*mode 0.017*diagon
33 0.050*diagon 0.026*arrai 0.026*ax 0.023*symmetr 0.023*eigenvalu 0.021*linearoper 0.020*hermitian 0.018*posit 0.016*matrix 0.015*save
34 0.080*fit 0.079*model 0.053*data 0.039*initi 0.028*condit 0.025*transform 0.024*coordin 0.019*origin 0.018*back 0.017*func
35 0.050*window 0.049*median 0.037*order 0.027*filter 0.024*histogram 0.023*digit 0.021*analog 0.019*contigu 0.016*chebyshev 0.014*blackman
36 0.039*product 0.031*cartesian 0.028*time 0.026*respons 0.022*print 0.022*option 0.020*arrai 0.020*system 0.018*continu 0.018*common
37 0.062*distanc 0.056*cluster 0.054*arrai 0.049*comput 0.038*boolean 0.024*dissimilar 0.023*euclidean 0.023*split 0.014*true 0.014*mask
38 0.057*matrix 0.047*random 0.034*complex 0.030*sum 0.030*rank 0.029*precis 0.027*id 0.027*svd 0.027*comput 0.026*real
39 0.033*polynomi 0.033*regress 0.029*target 0.028*rais 0.027*power 0.026*linear 0.025*weight 0.024*orthogon 0.022*function 0.020*predict

In [38]:
import dis
import itertools

l_ins = [[i.opname for i in dis.get_instructions(code)]
        for code in l_code]
all_ins = list(itertools.chain(*l_ins))
print(len(all_ins))
all_ins_unique = set(all_ins)
print(len(all_ins_unique))

code_model = g.models.word2vec.Word2Vec(l_ins, size=8, window=10, min_count=0, workers=4, seed=1337, iter=20)


339060
84

In [72]:
lll = lda
#index = g.similarities.SparseMatrixSimilarity(lll[corpus], num_features=22)
index = g.similarities.Similarity(None, lll[corpus], num_features=n_topics)

In [81]:
i_doc = 1959
text = texts[i_doc]
sims = index[lll[tfidf[id2word.doc2bow(text)]]]
print(i_doc, '', texts[i_doc], sep=' | ')
for i, score in list(sorted(enumerate(sims), key=lambda t : t[1], reverse=True))[:10] :
    print(i, "%.3f"%score, texts[i], sep=' | ')


1959 |  | ['popul', 'random', 'type', 'initi', 'cluster', 'latin', 'hypercub', 'sampl', 'gener']
587 | 1.000 | ['comput', 'neg', 'gradient']
853 | 1.000 | ['comput', 'partial', 'fraction', 'expans']
861 | 1.000 | ['comput', 'partial', 'fraction', 'expans']
869 | 1.000 | ['comput', 'partial', 'fraction', 'expans']
876 | 1.000 | ['comput', 'partial', 'fraction', 'expans']
1613 | 1.000 | ['calcul', 'phase', 'gener', 'output']
1629 | 1.000 | ['calcul', 'phase', 'gener', 'output']
1819 | 1.000 | ['set', 'storag', 'index', 'locat', 'valu']
1959 | 1.000 | ['popul', 'random', 'type', 'initi', 'cluster', 'latin', 'hypercub', 'sampl', 'gener']
1967 | 1.000 | ['partial', 'depend', 'plot', 'featur']

In [74]:
print([i_doc for i_doc, doc in enumerate(l_doc) if 'cluster' in doc])


[221, 254, 265, 273, 280, 803, 814, 848, 854, 862, 870, 882, 899, 937, 945, 960, 1207, 1229, 1235, 1244, 1268, 1272, 1277, 1327, 1355, 1889, 1934, 1959, 2056, 2059, 2074, 2101, 2112, 2168, 2181, 2184, 2186, 2190, 2193, 2195, 2787, 2850, 2901, 3053, 3142, 3159, 3470, 3476, 3521, 3565, 3595, 3600, 3604, 3610, 3619, 3655, 3772, 3775, 3784, 3791, 3794, 3806, 3815, 3823, 3825, 3830, 3833, 3840]

In [ ]: