In [1]:
import numpy as np
from distutils.version import LooseVersion
import warnings
import tensorflow as tf
import gensim

# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))


TensorFlow Version: 1.3.0
Default GPU Device: /gpu:0

In [2]:
# Load word2vec model
w2v = gensim.models.KeyedVectors.load('data/w2v-773752559-1000000-300-5-5-OpenSubtitles2016.bin')

In [3]:
def get_inputs(output_dim=300):
    """
    Create TF Placeholders for input, targets, learning_rate and input_sequence_length.
    :return: Tuple (input_, targets, learning_rate, keep_prob, input_sequence_length)
    """

    input_ = tf.placeholder(tf.int32, [None, None], name='input')
    targets = tf.placeholder(tf.float32, [None, output_dim])
    learning_rate = tf.placeholder(tf.float32)
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    input_sequence_length = tf.placeholder(tf.int32, [None], name="input_sequence_length")
    
    return (input_, targets, learning_rate, keep_prob, input_sequence_length)

In [4]:
def build_lstm(lstm_size, num_layers, batch_size, keep_prob, inputs, num_classes, input_sequence_length):
    ''' Build LSTM cell.
    
        Arguments
        ---------
        keep_prob: Scalar tensor (tf.placeholder) for the dropout keep probability
        lstm_size: Size of the hidden layers in the LSTM cells
        num_layers: Number of LSTM layers
        batch_size: Batch size

    '''
    
    # one_hot encode input
    x_one_hot = tf.one_hot(inputs, num_classes) # num_classes = len(vocab)
    
    def build_cell(rnn_size):
        cell = tf.contrib.rnn.LSTMCell(rnn_size, initializer=tf.contrib.layers.xavier_initializer())
        return cell
    
    # Construct a stacked tf.contrib.rnn.LSTMCell...
    stacked_cell = tf.contrib.rnn.MultiRNNCell([build_cell(lstm_size) for _ in range(num_layers)])
    # ...wrapped in a tf.contrib.rnn.DropoutWrapper
    cell = tf.contrib.rnn.DropoutWrapper(stacked_cell, output_keep_prob=keep_prob)
    
    # Pass cell and embedded input to tf.nn.dynamic_rnn()
    rnn_output, rnn_state = tf.nn.dynamic_rnn(cell, x_one_hot, sequence_length=input_sequence_length, dtype=tf.float32)
    
    # Initial state
    initial_state = tf.identity(stacked_cell.zero_state(batch_size, tf.float32), name="initial_state")
    
    return rnn_output, rnn_state, initial_state

We only care about the final rnn cell output. So we need to grab it with outputs[:, -1].


In [5]:
def build_output(cell, keep_prob, hidden_dim=1024, output_dim=300):
    input_ = cell[:, -1]
    dense = tf.contrib.layers.fully_connected(inputs=input_, num_outputs=hidden_dim, activation_fn=tf.nn.tanh)
    # dense = tf.nn.dropout(dense, keep_prob)
    dense = tf.layers.batch_normalization(dense)
    return tf.contrib.layers.fully_connected(dense, num_outputs=output_dim, activation_fn=None)
#     return tf.contrib.layers.fully_connected(inputs=cell[:, -1], num_outputs=output_dim, activation_fn=tf.nn.relu)

In [6]:
def get_loss(pred, Y):
    pred=tf.nn.l2_normalize(pred,0)
    Y=tf.nn.l2_normalize(Y,0)
    return tf.reduce_mean (1 - tf.reduce_sum(tf.multiply(pred,Y), axis=(1,), keep_dims=True))
#     return tf.losses.cosine_distance(pred, Y, dim=1)

In [7]:
# build vocabulary
vocab = sorted(set(" ".join(w2v.wv.index2word)))
vocab_to_int = {c: i for i, c in enumerate(vocab,1 )}

In [8]:
batch_size=256
lstm_size=1024
num_layers=3
keep_probability=0.8
num_classes=len(vocab)
output_dim=300
# learning_rate=0.001
learning_rate=0.0005
save_dir = './model/seq2vec'

# Create the graph object
graph = tf.Graph()

with graph.as_default():
    (input_, targets, lr, keep_prob, input_sequence_length) = get_inputs()
    with tf.variable_scope('LSTM'):
        rnn_output, rnn_state, initial_state = build_lstm(lstm_size, num_layers, batch_size, keep_prob, input_, num_classes, input_sequence_length)
    with tf.variable_scope('OUTPUT'):
        output = build_output(rnn_output, keep_prob, output_dim)
    with tf.variable_scope('LOSS'):
        loss = get_loss(output, targets)
    
    with tf.variable_scope('OPTIMIZER'):
        # Optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate)

        gradients = optimizer.compute_gradients(loss)
        # clip gradients
        capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients)
    
with tf.Session(graph=graph) as sess:
    #writer = tf.summary.FileWriter("log", sess.graph)
    sess.run(tf.global_variables_initializer())

    # tensorboard summary scalar for loss
    tf.summary.scalar("loss", loss)
    merged_summary = tf.summary.merge_all()
    
    # tensorboard writer graph
    writer=tf.summary.FileWriter("writer/1")
    writer.add_graph(sess.graph)
    
    # Save session
    saver = tf.train.Saver()
    saver.save(sess, save_dir, global_step=0)

#     # Optimizer for training, using gradient clipping to control exploding gradients
#     tvars = tf.trainable_variables()
#     grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), grad_clip)
#     train_op = tf.train.AdamOptimizer(learning_rate)
#     optimizer = train_op.apply_gradients(zip(grads, tvars))

In [9]:
# def word2seq(word):
#     return np.array([vocab_to_int.get(c,0) for c in word])

# def get_train_subset(model=w2v, seed_words=500, topn=7):
#     top_words=model.wv.index2word[120:120+seed_words]
#     top_words=np.append(np.array(top_words),np.array([np.array(model.most_similar_cosmul(w, topn=topn))[:,0] for w in top_words]))
#     top_words=top_words.flatten()
#     top_words=set(top_words)
#     return top_words

# input_list=list(get_train_subset())

In [10]:
def get_padded_int_batch(input_batch, vocab_to_int=vocab_to_int):
    max_len = max([len(word) for word in input_batch])
    int_batch =  [[0] * (max_len - len(l)) + [vocab_to_int[w] for w in l] for l in input_batch]
    return int_batch


def get_batch(input_list=w2v.wv.index2word, batch_size=batch_size, vocab=vocab, vocab_to_int=vocab_to_int, model=w2v):
    """
    Batch generator.
    Input: train_set - list of words
    Returns touple:
    (pad_input_batch, pad_input_lengths, output_batch)
    """
    for batch_i in range(0, len(input_list)//batch_size):
        start_i = batch_i * batch_size

        # Slice the right amount for the batch
        input_batch = input_list[start_i:start_i + batch_size]

        # Pad
        pad_input_batch = np.array(get_padded_int_batch(input_batch, vocab_to_int))

        # Need the lengths for the _lengths parameters
        pad_input_lengths = []
        for line in pad_input_batch:
            pad_input_lengths.append(len(line))
    
        # output batch
        output_batch=np.array([w2v.wv.word_vec(w) for w in input_batch])

        yield (pad_input_batch, pad_input_lengths, output_batch)


# for (batch_i, (pad_input_batch, pad_input_lengths, output)) in enumerate(get_batch(w2v.wv.index2word[:1000], batch_size=50)):
#     print (batch_i)
#     pass

train_size= 6000#00

train_input = w2v.wv.index2word[:train_size]
valid_input = w2v.wv.index2word[train_size:train_size+batch_size]

In [12]:
#%%time
#tf.reset_default_graph()

num_epochs=100
display_step=1


with tf.Session(graph=graph) as sess:
#     sess.run(tf.global_variables_initializer())
#     loader = tf.train.import_meta_graph(save_dir + '.meta')
    saver.restore(sess, saver.last_checkpoints[-1])

    for epoch_i in range(1, num_epochs):
        for batch_i, (pad_input_batch, pad_input_lengths, out_vec) in enumerate(get_batch(train_input)):

            _, _, l = sess.run([initial_state, train_op, loss], {
                input_: pad_input_batch,
                targets: out_vec,
                lr: learning_rate, 
                keep_prob: keep_probability,
                input_sequence_length: pad_input_lengths,
            })
        if (epoch_i % display_step ==0):
            (pad_input_batch, pad_input_lengths, out_vec) = next(get_batch(valid_input))
            _, valid_loss, ms = sess.run([initial_state, loss, merged_summary],{
                input_: pad_input_batch,
                targets: out_vec,
                keep_prob: 1.0,
                input_sequence_length: pad_input_lengths,
            })
            print("Epoch: {:3} | Loss: {:2.4}\t validation loss: {:2.4}".format(epoch_i, l, valid_loss))
            writer.add_summary(ms, epoch_i)
        # save model
        saver.save(sess, save_dir, global_step=epoch_i)


INFO:tensorflow:Restoring parameters from ./model/seq2vec-0
Epoch:   1 | Loss: 0.6965	 validation loss: 0.6934
Epoch:   2 | Loss: 0.6851	 validation loss: 0.6849
Epoch:   3 | Loss: 0.6849	 validation loss: 0.6854
Epoch:   4 | Loss: 0.6831	 validation loss: 0.6845
Epoch:   5 | Loss: 0.6801	 validation loss: 0.6828
Epoch:   6 | Loss: 0.68	 validation loss: 0.6828
Epoch:   7 | Loss: 0.6774	 validation loss: 0.6802
Epoch:   8 | Loss: 0.677	 validation loss: 0.6805
Epoch:   9 | Loss: 0.6776	 validation loss: 0.6803
Epoch:  10 | Loss: 0.6757	 validation loss: 0.6773
Epoch:  11 | Loss: 0.6765	 validation loss: 0.6741
Epoch:  12 | Loss: 0.669	 validation loss: 0.672
Epoch:  13 | Loss: 0.6656	 validation loss: 0.667
Epoch:  14 | Loss: 0.6679	 validation loss: 0.6645
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-12-af9a3843295a> in <module>()
     19                 lr: learning_rate,
     20                 keep_prob: keep_probability,
---> 21                 input_sequence_length: pad_input_lengths,
     22             })
     23         if (epoch_i % display_step ==0):

/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    893     try:
    894       result = self._run(None, fetches, feed_dict, options_ptr,
--> 895                          run_metadata_ptr)
    896       if run_metadata:
    897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1122     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1123       results = self._do_run(handle, final_targets, final_fetches,
-> 1124                              feed_dict_tensor, options, run_metadata)
   1125     else:
   1126       results = []

/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1319     if handle is None:
   1320       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1321                            options, run_metadata)
   1322     else:
   1323       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1325   def _do_call(self, fn, *args):
   1326     try:
-> 1327       return fn(*args)
   1328     except errors.OpError as e:
   1329       message = compat.as_text(e.message)

/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1304           return tf_session.TF_Run(session, options,
   1305                                    feed_dict, fetch_list, target_list,
-> 1306                                    status, run_metadata)
   1307 
   1308     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [19]:
saver.last_checkpoints[-1]
writer.add_graph(sess.graph)

In [20]:
%%time
tf.reset_default_graph()
learning_rate=0.00001

num_epochs=100
display_step=1

# tensorboard writer graph
writer=tf.summary.FileWriter("writer/3")

with tf.Session(graph=graph) as sess:
#     sess.run(tf.global_variables_initializer())
#    loader = tf.train.import_meta_graph(save_dir + '.meta')
    saver.restore(sess, saver.last_checkpoints[-1])

    for epoch_i in range(num_epochs):
        for batch_i, (pad_input_batch, pad_input_lengths, out_vec) in enumerate(get_batch(train_input)):

            _, _, l = sess.run([initial_state, train_op, loss], {
                input_: pad_input_batch,
                targets: out_vec,
                lr: learning_rate, 
                keep_prob: keep_probability,
                input_sequence_length: pad_input_lengths,
            })
        if (epoch_i % display_step ==0):
            (pad_input_batch, pad_input_lengths, out_vec) = next(get_batch(valid_input))
            _, valid_loss, ms = sess.run([initial_state, loss, merged_summary],{
                input_: pad_input_batch,
                targets: out_vec,
                keep_prob: 1.0,
                input_sequence_length: pad_input_lengths,
            })
            print("Epoch: {:3} | Loss: {:2.4}\t validation loss: {:2.4}".format(epoch_i, l, valid_loss))
            writer.add_summary(ms, epoch_i)
        # save model
        saver.save(sess, save_dir)


INFO:tensorflow:Restoring parameters from ./model/seq2vec
Epoch:   0 | Loss: 0.6172	 validation loss: 0.6232
Epoch:   1 | Loss: 0.6153	 validation loss: 0.6178
Epoch:   2 | Loss: 0.6096	 validation loss: 0.6153
Epoch:   3 | Loss: 0.6051	 validation loss: 0.6125
Epoch:   4 | Loss: 0.601	 validation loss: 0.6093
Epoch:   5 | Loss: 0.5958	 validation loss: 0.606
Epoch:   6 | Loss: 0.5926	 validation loss: 0.6031
Epoch:   7 | Loss: 0.5874	 validation loss: 0.6006
Epoch:   8 | Loss: 0.5817	 validation loss: 0.597
Epoch:   9 | Loss: 0.5738	 validation loss: 0.5933
Epoch:  10 | Loss: 0.5714	 validation loss: 0.5892
Epoch:  11 | Loss: 0.5612	 validation loss: 0.586
Epoch:  12 | Loss: 0.559	 validation loss: 0.5843
Epoch:  13 | Loss: 0.5493	 validation loss: 0.5778
Epoch:  14 | Loss: 0.5435	 validation loss: 0.5752
Epoch:  15 | Loss: 0.5339	 validation loss: 0.5707
Epoch:  16 | Loss: 0.524	 validation loss: 0.5711
Epoch:  17 | Loss: 0.5156	 validation loss: 0.5658
Epoch:  18 | Loss: 0.507	 validation loss: 0.5629
Epoch:  19 | Loss: 0.5108	 validation loss: 0.5638
Epoch:  20 | Loss: 0.4972	 validation loss: 0.56
Epoch:  21 | Loss: 0.4832	 validation loss: 0.5543
Epoch:  22 | Loss: 0.4674	 validation loss: 0.5585
Epoch:  23 | Loss: 0.4574	 validation loss: 0.5592
Epoch:  24 | Loss: 0.4563	 validation loss: 0.5591
Epoch:  25 | Loss: 0.4527	 validation loss: 0.552
Epoch:  26 | Loss: 0.4368	 validation loss: 0.5527
Epoch:  27 | Loss: 0.4308	 validation loss: 0.5491
Epoch:  28 | Loss: 0.4188	 validation loss: 0.5463
Epoch:  29 | Loss: 0.406	 validation loss: 0.5434
Epoch:  30 | Loss: 0.3961	 validation loss: 0.5438
Epoch:  31 | Loss: 0.391	 validation loss: 0.5375
Epoch:  32 | Loss: 0.3773	 validation loss: 0.5402
Epoch:  33 | Loss: 0.3695	 validation loss: 0.5385
Epoch:  34 | Loss: 0.3558	 validation loss: 0.5365
Epoch:  35 | Loss: 0.3432	 validation loss: 0.5347
Epoch:  36 | Loss: 0.3233	 validation loss: 0.5318
Epoch:  37 | Loss: 0.3139	 validation loss: 0.5294
Epoch:  38 | Loss: 0.3064	 validation loss: 0.533
Epoch:  39 | Loss: 0.2939	 validation loss: 0.5348
Epoch:  40 | Loss: 0.2827	 validation loss: 0.5313
Epoch:  41 | Loss: 0.2657	 validation loss: 0.5326
Epoch:  42 | Loss: 0.2581	 validation loss: 0.5325
Epoch:  43 | Loss: 0.2475	 validation loss: 0.5321
Epoch:  44 | Loss: 0.2383	 validation loss: 0.5329
Epoch:  45 | Loss: 0.231	 validation loss: 0.5277
Epoch:  46 | Loss: 0.2234	 validation loss: 0.5319
Epoch:  47 | Loss: 0.2149	 validation loss: 0.5306
Epoch:  48 | Loss: 0.1992	 validation loss: 0.5272
Epoch:  49 | Loss: 0.1952	 validation loss: 0.5296
Epoch:  50 | Loss: 0.1871	 validation loss: 0.5308
Epoch:  51 | Loss: 0.1882	 validation loss: 0.5285
Epoch:  52 | Loss: 0.1892	 validation loss: 0.5287
Epoch:  53 | Loss: 0.1787	 validation loss: 0.5281
Epoch:  54 | Loss: 0.1693	 validation loss: 0.5251
Epoch:  55 | Loss: 0.1639	 validation loss: 0.5236
Epoch:  56 | Loss: 0.1602	 validation loss: 0.5228
Epoch:  57 | Loss: 0.1522	 validation loss: 0.5245
Epoch:  58 | Loss: 0.1419	 validation loss: 0.5214
Epoch:  59 | Loss: 0.1376	 validation loss: 0.5229
Epoch:  60 | Loss: 0.1338	 validation loss: 0.5239
Epoch:  61 | Loss: 0.1295	 validation loss: 0.5244
Epoch:  62 | Loss: 0.1199	 validation loss: 0.5236
Epoch:  63 | Loss: 0.1132	 validation loss: 0.5245
Epoch:  64 | Loss: 0.1119	 validation loss: 0.5246
Epoch:  65 | Loss: 0.111	 validation loss: 0.5254
Epoch:  66 | Loss: 0.1251	 validation loss: 0.5249
Epoch:  67 | Loss: 0.1146	 validation loss: 0.5234
Epoch:  68 | Loss: 0.118	 validation loss: 0.5252
Epoch:  69 | Loss: 0.1018	 validation loss: 0.5266
Epoch:  70 | Loss: 0.1047	 validation loss: 0.5265
Epoch:  71 | Loss: 0.08491	 validation loss: 0.5272
Epoch:  72 | Loss: 0.08207	 validation loss: 0.5301
Epoch:  73 | Loss: 0.07518	 validation loss: 0.5319
Epoch:  74 | Loss: 0.07674	 validation loss: 0.5316
Epoch:  75 | Loss: 0.06545	 validation loss: 0.5291
Epoch:  76 | Loss: 0.06091	 validation loss: 0.5286
Epoch:  77 | Loss: 0.05759	 validation loss: 0.5265
Epoch:  78 | Loss: 0.05539	 validation loss: 0.5258
Epoch:  79 | Loss: 0.05333	 validation loss: 0.526
Epoch:  80 | Loss: 0.05147	 validation loss: 0.5259
Epoch:  81 | Loss: 0.05078	 validation loss: 0.5268
Epoch:  82 | Loss: 0.05021	 validation loss: 0.5269
Epoch:  83 | Loss: 0.04482	 validation loss: 0.5269
Epoch:  84 | Loss: 0.03767	 validation loss: 0.5261
Epoch:  85 | Loss: 0.03912	 validation loss: 0.5255
Epoch:  86 | Loss: 0.03442	 validation loss: 0.5269
Epoch:  87 | Loss: 0.03062	 validation loss: 0.5265
Epoch:  88 | Loss: 0.02986	 validation loss: 0.5262
Epoch:  89 | Loss: 0.02875	 validation loss: 0.5262
Epoch:  90 | Loss: 0.02635	 validation loss: 0.5279
Epoch:  91 | Loss: 0.02216	 validation loss: 0.5277
Epoch:  92 | Loss: 0.01943	 validation loss: 0.5281
Epoch:  93 | Loss: 0.01679	 validation loss: 0.5282
Epoch:  94 | Loss: 0.01547	 validation loss: 0.5303
Epoch:  95 | Loss: 0.01206	 validation loss: 0.5328
Epoch:  96 | Loss: 0.01155	 validation loss: 0.5361
Epoch:  97 | Loss: 0.009801	 validation loss: 0.5403
Epoch:  98 | Loss: 0.007907	 validation loss: 0.5395
Epoch:  99 | Loss: 0.005958	 validation loss: 0.5401
CPU times: user 5min 37s, sys: 1min 38s, total: 7min 15s
Wall time: 7min 6s

In [15]:
%%time
tf.reset_default_graph()
# Load model and use a global session
sess = tf.Session(graph=graph)
#loader = tf.train.import_meta_graph(save_dir + '.meta')
saver.restore(sess, saver.last_checkpoints[-1])


INFO:tensorflow:Restoring parameters from ./model/seq2vec
CPU times: user 716 ms, sys: 192 ms, total: 908 ms
Wall time: 641 ms
%%time tf.reset_default_graph() # Load model and use a global session sess = tf.Session() loader = tf.train.import_meta_graph(save_dir + '.meta') loader.restore(sess, save_dir)

In [16]:
def get_word2vec(word, sess=sess):
    pad_input_batch=get_padded_int_batch([word])
    #print(len(pad_input_batch[0]))
    _, outputs = sess.run([initial_state, output],{
        input_: pad_input_batch,
        keep_prob: 1.0,
        input_sequence_length: [len(pad_input_batch[0])],
    })
    return outputs[0]

In [17]:
#word=train_input[200]

word="kapitan"

wordvec=get_word2vec(word)

print(word)
w2v.wv.similar_by_vector(wordvec)


kapitan
Out[17]:
[('konfabulacje', 0.8141539692878723),
 ('samochodzik-zabaweczka', 0.8116906881332397),
 ('wykrad', 0.8109175562858582),
 ('bio-organiczny', 0.8095035552978516),
 ('skrotu', 0.8080761432647705),
 ('geniejalny', 0.8075656294822693),
 ('dtf', 0.8071662783622742),
 ('panteizm', 0.8054572343826294),
 ('zwyrolem', 0.8048051595687866),
 ('mier', 0.8045909404754639)]

In [18]:
get_word2vec("ula")
w2v.wv.index2word[990:1010]


Out[18]:
['takiej',
 'historii',
 'cel',
 'około',
 'siostra',
 'tom',
 'wolno',
 'złego',
 'myśl',
 'moment',
 'wieku',
 'boli',
 'f',
 'spróbować',
 'mimo',
 'niby',
 'prawdziwy',
 'kapitan',
 'przyjacielem',
 'czyż']

In [ ]: