LSTM text generation from Nietzsche's writings

This notebook is adapted from the Keras example script `lstm_text_generation.py`. The original script has the following message regarding speed:

At least 20 epochs are required before the generated text starts sounding coherent. It is recommended to run this script on GPU, as recurrent networks are quite computationally intensive. If you try this script on new data, make sure your corpus has at least ~100k characters. ~1M is better.


In [1]:
# Imports
from __future__ import print_function
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.layers import LSTM
from keras.utils.data_utils import get_file
import numpy as np
import random
import sys


Using Theano backend.

Get the data


In [2]:
# Get the data: download (or reuse the cached copy of) the corpus and
# lower-case it so the vocabulary stays small.
path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
# Read through a context manager so the file handle is closed right away
# instead of being left for the garbage collector.
with open(path) as corpus_file:
    text = corpus_file.read().lower()
print('corpus length:', len(text))

# Sort the vocabulary: iterating a bare set yields an arbitrary order, which
# would make the char <-> index mapping differ from one run to the next.
chars = sorted(set(text))
print('total chars:', len(chars))
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

# cut the text in semi-redundant sequences of maxlen characters:
# each window of `maxlen` characters is paired with the character that
# immediately follows it, and consecutive windows start `step` apart.
maxlen = 40
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))

# One-hot encode the windows into X (sequence, position, char) and the
# following characters into y (sequence, char).
print('Vectorization...')
# Use the builtin `bool` dtype: the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24.
X = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1


Downloading data from https://s3.amazonaws.com/text-datasets/nietzsche.txt
581632/600901 [============================>.] - ETA: 0scorpus length: 600893
total chars: 57
nb sequences: 200285
Vectorization...

Build the neural network


In [3]:
# Assemble the network: two LSTM layers stacked on top of each other, each
# followed by dropout, ending in a softmax over the character vocabulary.
print('Build model...')
model = Sequential([
    # First LSTM emits its full output sequence so the second LSTM has a
    # sequence to consume.
    LSTM(512, return_sequences=True, input_shape=(maxlen, len(chars))),
    Dropout(0.2),
    # Second LSTM emits only its final output.
    LSTM(512, return_sequences=False),
    Dropout(0.2),
    # One output unit per character, normalised by the softmax below.
    Dense(len(chars)),
    Activation('softmax'),
])

model.compile(optimizer='rmsprop', loss='categorical_crossentropy')


Build model...

Train the network and output some text at each step


In [4]:
def sample(a, temperature=1.0):
    """Sample an index from a probability array, reweighted by temperature.

    Each probability is raised to the power ``1 / temperature`` and the
    result is renormalised, so temperatures below 1 sharpen the distribution
    towards its largest entry and temperatures above 1 flatten it.

    Args:
        a: 1-D array-like of probabilities (e.g. one softmax output row).
        temperature: positive scaling factor applied to the log-probabilities.

    Returns:
        The integer index drawn from the reweighted distribution.
    """
    # Work in float64: float32 model outputs can normalise to a sum slightly
    # above 1, which np.random.multinomial rejects with a ValueError.
    a = np.asarray(a).astype('float64')
    # Zero probabilities legitimately map to -inf (and back to weight 0), so
    # silence the divide-by-zero warning that np.log would emit for them.
    with np.errstate(divide='ignore'):
        log_weights = np.log(a) / temperature
    # Shift by the maximum before exponentiating; this leaves the normalised
    # result unchanged but stops every term underflowing to 0 (-> 0/0 = NaN)
    # at very low temperatures.
    log_weights = log_weights - np.max(log_weights)
    weights = np.exp(log_weights)
    probabilities = weights / np.sum(weights)
    # A single multinomial draw is a one-hot vector; argmax recovers its index.
    return np.argmax(np.random.multinomial(1, probabilities, 1))

# train the model, output generated text after each iteration
# (range(1, 60) yields iterations 1..59; each one fits for a single epoch)
for iteration in range(1, 60):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    # One pass over the full training arrays per outer iteration.
    model.fit(X, y, batch_size=128, nb_epoch=1)

    # Pick a random seed position. random.randint is inclusive at both ends,
    # so this upper bound keeps the maxlen-character slice inside the text.
    start_index = random.randint(0, len(text) - maxlen - 1)

    # Generate from the same seed at several sampling temperatures so the
    # outputs can be compared side by side.
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        print()
        print('----- diversity:', diversity)

        # `generated` accumulates the seed plus every sampled character;
        # `sentence` is the sliding window fed to the model.
        generated = ''
        sentence = text[start_index: start_index + maxlen]
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        # Sample 400 characters, one model prediction per character.
        for i in range(400):
            # One-hot encode the current window as a batch of size 1.
            x = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x[0, t, char_indices[char]] = 1.

            # Take the single prediction row and draw the next character
            # from it using the current diversity as the temperature.
            preds = model.predict(x, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            # Slide the window: drop the oldest character, append the new one.
            generated += next_char
            sentence = sentence[1:] + next_char

            # Write and flush immediately so text appears as it is generated.
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()


--------------------------------------------------
Iteration 1
Epoch 1/1
 12672/200285 [>.............................] - ETA: 2926s - loss: 3.1116
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-4-9c0ee3d83b4e> in <module>()
     10     print('-' * 50)
     11     print('Iteration', iteration)
---> 12     model.fit(X, y, batch_size=128, nb_epoch=1)
     13 
     14     start_index = random.randint(0, len(text) - maxlen - 1)

/Users/kgullikson/anaconda/lib/python3.5/site-packages/Keras-1.0.4-py3.5.egg/keras/models.py in fit(self, x, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, **kwargs)
    406                               shuffle=shuffle,
    407                               class_weight=class_weight,
--> 408                               sample_weight=sample_weight)
    409 
    410     def evaluate(self, x, y, batch_size=32, verbose=1,

/Users/kgullikson/anaconda/lib/python3.5/site-packages/Keras-1.0.4-py3.5.egg/keras/engine/training.py in fit(self, x, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight)
   1067                               verbose=verbose, callbacks=callbacks,
   1068                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-> 1069                               callback_metrics=callback_metrics)
   1070 
   1071     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):

/Users/kgullikson/anaconda/lib/python3.5/site-packages/Keras-1.0.4-py3.5.egg/keras/engine/training.py in _fit_loop(self, f, ins, out_labels, batch_size, nb_epoch, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics)
    789                 batch_logs['size'] = len(batch_ids)
    790                 callbacks.on_batch_begin(batch_index, batch_logs)
--> 791                 outs = f(ins_batch)
    792                 if type(outs) != list:
    793                     outs = [outs]

/Users/kgullikson/anaconda/lib/python3.5/site-packages/Keras-1.0.4-py3.5.egg/keras/backend/theano_backend.py in __call__(self, inputs)
    519     def __call__(self, inputs):
    520         assert type(inputs) in {list, tuple}
--> 521         return self.function(*inputs)
    522 
    523 

/Users/kgullikson/anaconda/lib/python3.5/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    857         t0_fn = time.time()
    858         try:
--> 859             outputs = self.fn()
    860         except Exception:
    861             if hasattr(self.fn, 'position_of_error'):

/Users/kgullikson/anaconda/lib/python3.5/site-packages/theano/scan_module/scan_op.py in rval(p, i, o, n, allow_gc)
    949         def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
    950                  allow_gc=allow_gc):
--> 951             r = p(n, [x[0] for x in i], o)
    952             for o in node.outputs:
    953                 compute_map[o][0] = True

/Users/kgullikson/anaconda/lib/python3.5/site-packages/theano/scan_module/scan_op.py in <lambda>(node, args, outs)
    938                         args,
    939                         outs,
--> 940                         self, node)
    941         except (ImportError, theano.gof.cmodule.MissingGXX):
    942             p = self.execute

KeyboardInterrupt: 

In [ ]: