Using an RNN to produce poetry

In this note, I am going to use an RNN to build a model that produces poetry. This is a coding exercise from a Udemy deep learning course.

Prepare Data

The functions below process robert_frost.txt and transform it into the format required for model training. `remove_punctuation` strips punctuation, since our model does not need it. `get_robert_frost` parses each non-empty line, lower-cases it, and adds every new word to a word-to-index dictionary. The output is a list of all sentences in which each word has been replaced by its index, together with that dictionary.


In [1]:
import string
import theano
import theano.tensor as T
import numpy as np
import matplotlib.pyplot as plt

from sklearn.utils import shuffle

def init_weight(Mi, Mo):
    """Return an (Mi, Mo) matrix of standard-normal weights.

    The values are divided by sqrt(Mi + Mo) so their variance shrinks as the
    layer gets wider.
    """
    scale = np.sqrt(Mi + Mo)
    weights = np.random.randn(Mi, Mo)
    return weights / scale

def remove_punctuation(s):
    """Return a copy of *s* with every character in ``string.punctuation`` deleted."""
    # The third argument of str.maketrans lists characters to map to None (delete).
    deletion_table = str.maketrans('', '', string.punctuation)
    return s.translate(deletion_table)

def get_robert_frost(path='../data/robert_frost.txt'):
    """Load the Robert Frost corpus as sentences of word indices.

    Each non-empty line of the file becomes one sentence. The line is
    lower-cased, stripped of punctuation by ``remove_punctuation``, split on
    whitespace, and every token is replaced by its integer index.

    Args:
        path: Text file to read, one sentence per line. Defaults to the
            original hard-coded location.

    Returns:
        A ``(sentences, word2idx)`` tuple.
        ``sentences`` is a list of lists of ints, one per non-empty line.
        ``word2idx`` maps each word to its index. Indices 0 and 1 are reserved
        for the ``'START'`` and ``'END'`` markers, which are *not* inserted into
        ``sentences`` here.
    """
    word2idx = {'START': 0, 'END': 1}
    sentences = []
    # `with` guarantees the file handle is closed once reading is done.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            tokens = remove_punctuation(line.lower()).split()
            # Indices are assigned densely in first-seen order, so the next free
            # index is always len(word2idx). setdefault only inserts unseen words
            # and returns the (existing or new) index in either case.
            sentences.append([word2idx.setdefault(t, len(word2idx)) for t in tokens])
    return sentences, word2idx

In [2]:
class SimpleRNN:
    """Word-level simple recurrent network for language modelling (draft).

    So far only the parameters and the symbolic forward pass are built; see
    ``fit``.
    """

    def __init__(self, D, M, V):
        """Store the model dimensions.

        Args:
            D: dimensionality of the word embeddings.
            M: hidden layer size.
            V: vocabulary size (number of rows of the embedding matrix).
        """
        self.D = D  # dim of word embedding
        self.M = M  # hidden layer size
        self.V = V  # vocabulary size

    def fit(self, X, learning_rate=10e-1, mu=0.99, reg=1.0, activation=T.tanh, epochs=500, show_fig=False):
        """Initialise the parameters and build the symbolic forward pass.

        NOTE(review): this method is unfinished. ``learning_rate`` (note that
        ``10e-1`` == 1.0), ``mu``, ``reg``, ``epochs`` and ``show_fig`` are
        accepted but not used yet. No cost, gradient updates or training loop
        are constructed, and the scan outputs are currently discarded.

        Args:
            X: training sentences; presumably a list of lists of word indices
                as returned by ``get_robert_frost`` -- TODO confirm.
            activation: elementwise nonlinearity for the hidden layer.
        """
        N = len(X)  # number of sentences (not used yet)
        D = self.D
        M = self.M
        V = self.V
        self.f = activation  # for hidden layer

        # initial weights
        We = init_weight(V, D)  # embedding matrix: one D-dim row per word index
        Wx = init_weight(D, M)
        Wh = init_weight(M, M)
        bh = np.zeros(M)
        h0 = np.zeros(M)
        Wo = init_weight(M, V)
        bo = np.zeros(V)

        # make them theano shared
        self.We = theano.shared(We)
        self.Wx = theano.shared(Wx)
        self.Wh = theano.shared(Wh)
        self.bh = theano.shared(bh)
        self.h0 = theano.shared(h0)
        self.Wo = theano.shared(Wo)
        self.bo = theano.shared(bo)
        self.params = [self.We, self.Wx, self.Wh, self.bh, self.h0, self.Wo, self.bo]

        # sentence input:
        # [START, w1, w2, ..., wn]
        # sentence target:
        # [w1,    w2, w3, ..., END]
        thX = T.ivector('X')  # the sequence. will have length T. Note each sequence will have different T
        Ei = self.We[thX]  # embedding lookup: this will be a TxD matrix
        thY = T.ivector('Y')  # targets; not used yet

        def recurrence(x_t, h_t1):
            # return h(t), y(t)
            # Use the shared self.Wx rather than the local NumPy array Wx.
            # Only the shared variable is the parameter listed in self.params;
            # the NumPy array would enter the graph as a fixed constant.
            h_t = self.f(x_t.dot(self.Wx) + h_t1.dot(self.Wh) + self.bh)
            # NOTE(review): Theano's softmax of a 1-D input is returned as a
            # (1, V) row, so the scanned y is likely (T, 1, V) and will need
            # y[:, 0, :] before use -- confirm.
            y_t = T.nnet.softmax(h_t.dot(self.Wo) + self.bo)
            return h_t, y_t

        # theano.scan's keyword is `outputs_info` (the original `output_info`
        # is not a valid scan argument). None marks y as a non-recurrent output.
        [h, y], _ = theano.scan(
            fn=recurrence,
            outputs_info=[self.h0, None],
            sequences=Ei,
            n_steps=Ei.shape[0],
        )

In [3]:
import numpy as np
import theano
import theano.tensor as T


N = T.iscalar('N')


def recurrence(n, fn_1, fn_2):
    """Compute one Fibonacci step for theano.scan.

    scan passes the current element of ``sequences`` first (``n``, unused
    here). The two recurrent values seeded by ``outputs_info`` follow.
    """
    fn_next = fn_1 + fn_2
    # Emit the new term and the previous one. Because outputs_info seeds two
    # values, scan feeds both outputs back in as (fn_1, fn_2) on the next step.
    return fn_next, fn_1


outputs, _ = theano.scan(
    fn=recurrence,
    # If this is removed (or set to sequences=[]), recurrence() must drop its n argument.
    sequences=T.arange(N),
    n_steps=N,
    # Must be a list with one initial value per output returned by recurrence().
    outputs_info=[1., 1.],
)

# Keep only the first scanned output: the sequence of Fibonacci terms.
fibonacci = theano.function(inputs=[N], outputs=outputs[0])

o_val = fibonacci(8)
print("output:", o_val)


output: [  2.   3.   5.   8.  13.  21.  34.  55.]