In [1]:
# -*- coding: utf-8 -*-
# Import Packages
import numpy as np
import tensorflow as tf
import collections
import string
import argparse
import time
import os
from six.moves import cPickle
from TextLoader import *
from Hangulpy import *
print ("Packages Imported")
In [2]:
corpus_name = "invisible_dragon" # "nine_dreams"
# corpus_name = "nine_dreams"
data_dir = "data/" + corpus_name
batch_size = 10
seq_length = 100
data_loader = TextLoader(data_dir, batch_size, seq_length)
# This makes "vocab.pkl" and "data.npy" inside data_dir
# (e.g. "data/invisible_dragon") from data_dir + "/input.txt"
vocab_size = data_loader.vocab_size
vocab = data_loader.vocab
chars = data_loader.chars
print ( "type of 'data_loader' is %s, length is %d"
% (type(data_loader.vocab), len(data_loader.vocab)) )
print ( "\n" )
print ("data_loader.vocab looks like \n%s " %
(data_loader.vocab))
print ( "\n" )
print ( "type of 'data_loader.chars' is %s, length is %d"
% (type(data_loader.chars), len(data_loader.chars)) )
print ( "\n" )
print ("data_loader.chars looks like \n%s " % (data_loader.chars,))
In [3]:
rnn_size = 128
num_layers = 2
grad_clip = 5.
_batch_size = 1
_seq_length = 1
vocab_size = data_loader.vocab_size
with tf.device("/cpu:0"):
# Select RNN Cell
def unit_cell():
return tf.contrib.rnn.BasicLSTMCell(rnn_size,state_is_tuple=True,reuse=tf.get_variable_scope().reuse)
cell = tf.contrib.rnn.MultiRNNCell([unit_cell() for _ in range(num_layers)])
# Set paths to the graph
input_data = tf.placeholder(tf.int32, [_batch_size, _seq_length])
targets = tf.placeholder(tf.int32, [_batch_size, _seq_length])
initial_state = cell.zero_state(_batch_size, tf.float32)
# Set Network
with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [rnn_size, vocab_size])
softmax_b = tf.get_variable("softmax_b", [vocab_size])
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [vocab_size, rnn_size])
inputs = tf.split(tf.nn.embedding_lookup(embedding, input_data), _seq_length, 1)
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
# Loop function for seq2seq
def loop(prev, _):
prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b)
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
# Output of RNN
outputs, last_state = tf.contrib.rnn.static_rnn(cell,inputs, initial_state
, scope='rnnlm')
output = tf.reshape(tf.concat(outputs,1), [-1, rnn_size])
logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
# Next word probability
probs = tf.nn.softmax(logits)
final_state = last_state
print ("Network Ready")
In [4]:
# Sample!
def sample(sess, chars, vocab, __probs, num=200, prime=u'ㅇㅗᴥㄴㅡㄹᴥ '):
    state = sess.run(cell.zero_state(1, tf.float32))
    _probs = __probs
    prime = list(prime)
    # Warm up the RNN state on every prime character except the last one
    for char in prime[:-1]:
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {input_data: x, initial_state: state}
        [state] = sess.run([last_state], feed)

    def weighted_pick(weights):
        weights = weights / np.sum(weights)
        t = np.cumsum(weights)
        s = np.sum(weights)
        return int(np.searchsorted(t, np.random.rand(1) * s))

    # Generate 'num' characters, feeding each sampled character back in
    ret = prime
    char = prime[-1]
    for n in range(num):
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {input_data: x, initial_state: state}
        [_probsval, state] = sess.run([_probs, final_state], feed)
        p = _probsval[0]
        p = p / np.sum(p)  # renormalize to guard against float32 rounding
        sample = int(np.random.choice(len(p), p=p))
        # sample = weighted_pick(p)
        pred = chars[sample]
        ret += pred
        char = pred
    return ret
print ("sampling function done.")
In [7]:
save_dir = 'data/' + corpus_name
prime = decompose_text(u" ")
print ("Prime Text : %s => %s" % (automata(prime), "".join(prime)))
n = 4000
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
ckpt = tf.train.get_checkpoint_state(save_dir)
# load_name = u'data/nine_dreams/model.ckpt-0'
load_name = os.path.join(save_dir, 'model.ckpt-20000')
print (load_name)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, load_name)
    sampled_text = sample(sess, chars, vocab, probs, n, prime)
    # print ("")
    # print (u"SAMPLED TEXT = %s" % sampled_text)
    # print ("")
    print ("-- RESULT --")
    print (automata("".join(sampled_text)))
It takes a long time to train!
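The model generates decomposed jamo, so automata() is used to recompose readable syllables. A small round-trip sketch, assuming Hangulpy's decompose_text() splits each syllable into jamo (with ᴥ marking the end of a syllable) and automata() reverses it, as they are used above:
In [ ]:
jamo = decompose_text(u"오늘")    # e.g. u'ㅇㅗᴥㄴㅡㄹᴥ'
print (jamo)
print (automata(jamo))            # recomposed syllables: 오늘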
In [ ]: