In [1]:
import os
import time
from six.moves import cPickle
import numpy as np
import tensorflow as tf
In [2]:
load_dir = "data/"


def _read_pickle(file_name):
    """Open `file_name` inside `load_dir` and unpickle it.

    Returns a ``(path, obj)`` pair so the caller can keep the last path in
    ``load_name`` exactly as before.
    """
    path = os.path.join(load_dir, file_name)
    with open(path, 'rb') as handle:
        return path, cPickle.load(handle)


# NOTE: unpickling can execute arbitrary code -- only load files you trust.
load_name, (chars, vocab) = _read_pickle('chars_vocab.pkl')
load_name, (corpus, data) = _read_pickle('corpus_data.pkl')
In [3]:
# Hyper-parameters for sampling: one sequence, one token per step.
batch_size = 1
seq_length = 1
vocab_size = len(vocab)
rnn_size = 128
num_layers = 2
grad_clip = 5.

# Build into a fresh graph so re-running this cell never hits
# "variable already exists" errors. This covers all variables, including
# the ones rnn_decoder creates, not just the softmax parameters.
tf.reset_default_graph()

# Stacked LSTM language model.
unitcell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
cell = tf.nn.rnn_cell.MultiRNNCell([unitcell] * num_layers)
input_data = tf.placeholder(tf.int32, [batch_size, seq_length])
targets = tf.placeholder(tf.int32, [batch_size, seq_length])
istate = cell.zero_state(batch_size, tf.float32)

# Variable names under 'rnnlm' must match those in the checkpoint restored later.
with tf.variable_scope('rnnlm'):
    softmax_w = tf.get_variable("softmax_w", [rnn_size, vocab_size])
    softmax_b = tf.get_variable("softmax_b", [vocab_size])
    embedding = tf.get_variable("embedding", [vocab_size, rnn_size])
    # [batch, seq] ids -> [batch, seq, rnn_size] embeddings, then split into a
    # list of seq_length tensors of shape [batch, rnn_size] for rnn_decoder.
    inputs = tf.split(1, seq_length, tf.nn.embedding_lookup(embedding, input_data))
    inputs = [tf.squeeze(_input, [1]) for _input in inputs]

outputs, last_state = tf.nn.seq2seq.rnn_decoder(
    inputs, istate, cell, loop_function=None, scope='rnnlm')
output = tf.reshape(tf.concat(1, outputs), [-1, rnn_size])
logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
probs = tf.nn.softmax(logits)
loss = tf.nn.seq2seq.sequence_loss_by_example(
    [logits],                               # Input
    [tf.reshape(targets, [-1])],            # Target
    [tf.ones([batch_size * seq_length])],   # Weight
    vocab_size)
cost = tf.reduce_sum(loss) / batch_size / seq_length
final_state = last_state

# Training ops; the sampling cells below only run `probs` and `final_state`.
lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
_optm = tf.train.AdamOptimizer(lr)
optm = _optm.apply_gradients(zip(grads, tvars))
In [4]:
# Create a session, initialise all variables, then overwrite them from the
# latest checkpoint found in `load_dir`.
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(load_dir)
# get_checkpoint_state may return None when no checkpoint is found; fail with
# a clear message instead of an AttributeError on `None.model_checkpoint_path`.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No TensorFlow checkpoint found in %r" % (load_dir,))
saver.restore(sess, ckpt.model_checkpoint_path)
In [5]:
def weighted_pick(weights):
    """Sample an index with probability proportional to `weights`.

    Args:
        weights: 1-D sequence of non-negative numbers (e.g. one softmax row).

    Returns:
        An ``int`` index in ``range(len(weights))``.

    Raises:
        ValueError: if `weights` is empty or sums to zero or less.
    """
    cumulative = np.cumsum(weights)
    if cumulative.size == 0:
        raise ValueError("weighted_pick() needs at least one weight")
    # Use the final cumulative value as the total, rather than np.sum. The two
    # can differ by rounding, and a threshold above cumulative[-1] would yield
    # len(weights), which is out of range.
    total = cumulative[-1]
    if not total > 0:
        raise ValueError("weighted_pick() needs weights with a positive sum")
    threshold = np.random.rand() * total  # scalar in [0, total)
    # side='right' returns the first index whose cumulative sum is strictly
    # greater than the threshold, so zero-weight entries are never selected.
    return int(np.searchsorted(cumulative, threshold, side='right'))
prime = "START"
num = 10  # number of tokens to generate after the prime

# NOTE(review): the prime token id is hard-coded as 1; presumably this is
# vocab[prime] in chars_vocab.pkl -- confirm.
prime_idx = 1

# Prime the network: run one step from the zero state on the prime token.
state = sess.run(cell.zero_state(1, tf.float32))
x = np.zeros((1, 1))
x[0, 0] = prime_idx
state = sess.run(final_state, feed_dict={input_data: x, istate: state})

ret = prime
char = prime[-1]
next_idx = prime_idx
for n in range(num):
    x = np.zeros((1, 1))
    # Feed back the previously sampled token so generation is autoregressive.
    # Feeding a constant here would give every step the same input.
    x[0, 0] = next_idx
    [probsval, state] = sess.run([probs, final_state],
                                 feed_dict={input_data: x, istate: state})
    p = probsval[0]
    sample = weighted_pick(p)
    pred = chars[sample]
    ret = ret + " " + pred
    char = pred
    next_idx = sample
print(ret)
In [ ]: