In [3]:
import os
import pandas as pd
import numpy as np
from keras.models import Model
from ch10 import construct_seq2seq_model
from nlpia.loaders import get_data, DATA_PATH

In [11]:
batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
num_samples = 10000  # Number of dialogue pairs to train on.
data_path = os.path.join(DATA_PATH, 'movie_dialog.txt')  # preprocessed Cornell movie dialog samples

In [12]:
try:
    import cPickle as pickle
except ImportError:
    import pickle

from io import open

with open("../data/characters_stats.pkl", "rb") as filehandler:
    input_characters, target_characters, input_token_index, target_token_index = pickle.load(filehandler)

with open("../data/encoder_decoder_stats.pkl", "rb") as filehandler:
    num_encoder_tokens, num_decoder_tokens, max_encoder_seq_length, max_decoder_seq_length = pickle.load(filehandler)
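
If the pickled stats are missing, the two token-index dictionaries can be rebuilt from the sorted character sets (the next cell recomputes those sets). A minimal sketch, valid only if this enumeration order matches the one the saved model was trained with:

>>> input_token_index = dict(
...     (char, i) for i, char in enumerate(sorted(input_characters)))
>>> target_token_index = dict(
...     (char, i) for i, char in enumerate(sorted(target_characters)))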

In [13]:
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
lines = open(data_path).read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)


Number of samples: 10000
Number of unique input tokens: 44
Number of unique output tokens: 46
Max sequence length for inputs: 100
Max sequence length for outputs: 102

In [14]:
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
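
These one-hot tensors are dense float32 arrays, so it is worth checking their size before allocating them. From the stats printed above (10,000 samples, max lengths 100 and 102, vocabularies of 44 and 46 characters), the encoder array works out to about 176 MB and each decoder array to about 188 MB:

>>> print(encoder_input_data.nbytes / 1e6, 'MB')   # ~176 MB
>>> print(decoder_input_data.nbytes / 1e6, 'MB')   # ~188 MB
>>> print(decoder_target_data.nbytes / 1e6, 'MB')  # ~188 MB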

In [15]:
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
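
A quick sanity check on the encoding is to map the first sample back to text with argmax and a reversed index (a sketch; reverse_input_char_index is built here only for the check and should reproduce input_texts[0] if the pickled token index matches this corpus):

>>> reverse_input_char_index = dict(
...     (i, char) for char, i in input_token_index.items())
>>> first = encoder_input_data[0, :len(input_texts[0])].argmax(axis=-1)
>>> ''.join(reverse_input_char_index[i] for i in first)  # should equal input_texts[0]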

In [32]:
# model = construct_seq2seq_model(num_encoder_tokens, num_decoder_tokens)
from keras.layers import Input, LSTM, Dense
batch_size = 64    # <1>
epochs = 100       # <2>
num_neurons = 256  # <3>

encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(num_neurons, return_state=True)
_, state_h, state_c = encoder(encoder_inputs)

encoder_states = [state_h, state_c]

decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(num_neurons, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
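
Before training, model.summary() is a cheap way to confirm the wiring. Each LSTM layer holds 4 * (n * (n + m) + n) weights, with n = num_neurons = 256 and m the input width (44 for the encoder, 46 for the decoder), plus 256 * 46 + 46 weights in the softmax layer, so the total should come to roughly 630,000 trainable parameters:

>>> model.summary()  # encoder LSTM ~308k, decoder LSTM ~310k, dense ~12k params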

In [8]:
# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)


Train on 8000 samples, validate on 2000 samples
Epoch 1/100
8000/8000 [==============================] - 215s 27ms/step - loss: 1.0202 - acc: 0.0647 - val_loss: 0.8652 - val_acc: 0.0855
Epoch 2/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.8518 - acc: 0.1027 - val_loss: 0.7556 - val_acc: 0.1125
Epoch 3/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.7695 - acc: 0.1197 - val_loss: 0.7000 - val_acc: 0.1200
Epoch 4/100
8000/8000 [==============================] - 228s 28ms/step - loss: 0.7238 - acc: 0.1303 - val_loss: 0.6640 - val_acc: 0.1306
Epoch 5/100
8000/8000 [==============================] - 227s 28ms/step - loss: 0.6899 - acc: 0.1391 - val_loss: 0.6448 - val_acc: 0.1364
Epoch 6/100
8000/8000 [==============================] - 231s 29ms/step - loss: 0.6647 - acc: 0.1464 - val_loss: 0.6160 - val_acc: 0.1434
Epoch 7/100
8000/8000 [==============================] - 232s 29ms/step - loss: 0.6423 - acc: 0.1515 - val_loss: 0.6016 - val_acc: 0.1468
Epoch 8/100
8000/8000 [==============================] - 417s 52ms/step - loss: 0.6233 - acc: 0.1560 - val_loss: 0.5868 - val_acc: 0.1506
Epoch 9/100
8000/8000 [==============================] - 422s 53ms/step - loss: 0.6069 - acc: 0.1608 - val_loss: 0.5747 - val_acc: 0.1527
Epoch 10/100
8000/8000 [==============================] - 403s 50ms/step - loss: 0.5927 - acc: 0.1645 - val_loss: 0.5641 - val_acc: 0.1568
Epoch 11/100
8000/8000 [==============================] - 404s 50ms/step - loss: 0.5798 - acc: 0.1681 - val_loss: 0.5554 - val_acc: 0.1579
Epoch 12/100
8000/8000 [==============================] - 234s 29ms/step - loss: 0.5688 - acc: 0.1714 - val_loss: 0.5514 - val_acc: 0.1581
Epoch 13/100
8000/8000 [==============================] - 199s 25ms/step - loss: 0.5583 - acc: 0.1739 - val_loss: 0.5411 - val_acc: 0.1628
Epoch 14/100
8000/8000 [==============================] - 213s 27ms/step - loss: 0.5499 - acc: 0.1764 - val_loss: 0.5385 - val_acc: 0.1630
Epoch 15/100
8000/8000 [==============================] - 196s 24ms/step - loss: 0.5403 - acc: 0.1792 - val_loss: 0.5307 - val_acc: 0.1652
Epoch 16/100
8000/8000 [==============================] - 196s 25ms/step - loss: 0.5320 - acc: 0.1811 - val_loss: 0.5274 - val_acc: 0.1659
Epoch 17/100
8000/8000 [==============================] - 203s 25ms/step - loss: 0.5244 - acc: 0.1831 - val_loss: 0.5251 - val_acc: 0.1661
Epoch 18/100
8000/8000 [==============================] - 250s 31ms/step - loss: 0.5170 - acc: 0.1854 - val_loss: 0.5220 - val_acc: 0.1678
Epoch 19/100
8000/8000 [==============================] - 203s 25ms/step - loss: 0.5102 - acc: 0.1874 - val_loss: 0.5208 - val_acc: 0.1681
Epoch 20/100
8000/8000 [==============================] - 198s 25ms/step - loss: 0.5036 - acc: 0.1891 - val_loss: 0.5183 - val_acc: 0.1684
Epoch 21/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.5064 - acc: 0.1893 - val_loss: 0.5160 - val_acc: 0.1698
Epoch 22/100
8000/8000 [==============================] - 224s 28ms/step - loss: 0.4918 - acc: 0.1924 - val_loss: 0.5148 - val_acc: 0.1698
Epoch 23/100
8000/8000 [==============================] - 235s 29ms/step - loss: 0.4858 - acc: 0.1942 - val_loss: 0.5147 - val_acc: 0.1705
Epoch 24/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.4800 - acc: 0.1959 - val_loss: 0.5141 - val_acc: 0.1702
Epoch 25/100
8000/8000 [==============================] - 232s 29ms/step - loss: 0.4746 - acc: 0.1974 - val_loss: 0.5151 - val_acc: 0.1695
Epoch 26/100
8000/8000 [==============================] - 249s 31ms/step - loss: 0.4690 - acc: 0.1990 - val_loss: 0.5163 - val_acc: 0.1703
Epoch 27/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.4636 - acc: 0.2006 - val_loss: 0.5158 - val_acc: 0.1701
Epoch 28/100
8000/8000 [==============================] - 235s 29ms/step - loss: 0.4676 - acc: 0.2000 - val_loss: 0.5173 - val_acc: 0.1701
Epoch 29/100
8000/8000 [==============================] - 254s 32ms/step - loss: 0.4545 - acc: 0.2032 - val_loss: 0.5191 - val_acc: 0.1699
Epoch 30/100
8000/8000 [==============================] - 240s 30ms/step - loss: 0.4512 - acc: 0.2045 - val_loss: 0.5185 - val_acc: 0.1704
Epoch 31/100
8000/8000 [==============================] - 234s 29ms/step - loss: 0.4463 - acc: 0.2060 - val_loss: 0.5198 - val_acc: 0.1702
Epoch 32/100
8000/8000 [==============================] - 233s 29ms/step - loss: 0.4415 - acc: 0.2072 - val_loss: 0.5234 - val_acc: 0.1695
Epoch 33/100
8000/8000 [==============================] - 239s 30ms/step - loss: 0.4356 - acc: 0.2089 - val_loss: 0.5246 - val_acc: 0.1690
Epoch 34/100
8000/8000 [==============================] - 234s 29ms/step - loss: 0.4300 - acc: 0.2107 - val_loss: 0.5280 - val_acc: 0.1691
Epoch 35/100
8000/8000 [==============================] - 256s 32ms/step - loss: 0.4249 - acc: 0.2123 - val_loss: 0.5302 - val_acc: 0.1678
Epoch 36/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.4200 - acc: 0.2136 - val_loss: 0.5310 - val_acc: 0.1683
Epoch 37/100
8000/8000 [==============================] - 208s 26ms/step - loss: 0.4153 - acc: 0.2151 - val_loss: 0.5342 - val_acc: 0.1689
Epoch 38/100
8000/8000 [==============================] - 208s 26ms/step - loss: 0.4104 - acc: 0.2167 - val_loss: 0.5369 - val_acc: 0.1683
Epoch 39/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.4062 - acc: 0.2183 - val_loss: 0.5404 - val_acc: 0.1673
Epoch 40/100
8000/8000 [==============================] - 208s 26ms/step - loss: 0.4011 - acc: 0.2197 - val_loss: 0.5435 - val_acc: 0.1665
Epoch 41/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3967 - acc: 0.2208 - val_loss: 0.5465 - val_acc: 0.1669
Epoch 42/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3925 - acc: 0.2224 - val_loss: 0.5504 - val_acc: 0.1664
Epoch 43/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3881 - acc: 0.2236 - val_loss: 0.5536 - val_acc: 0.1662
Epoch 44/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3838 - acc: 0.2250 - val_loss: 0.5579 - val_acc: 0.1653
Epoch 45/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3796 - acc: 0.2262 - val_loss: 0.5600 - val_acc: 0.1657
Epoch 46/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3756 - acc: 0.2276 - val_loss: 0.5645 - val_acc: 0.1649
Epoch 47/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3720 - acc: 0.2287 - val_loss: 0.5678 - val_acc: 0.1649
Epoch 48/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.3677 - acc: 0.2302 - val_loss: 0.5715 - val_acc: 0.1650
Epoch 49/100
8000/8000 [==============================] - 212s 27ms/step - loss: 0.3639 - acc: 0.2316 - val_loss: 0.5756 - val_acc: 0.1641
Epoch 50/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3603 - acc: 0.2325 - val_loss: 0.5816 - val_acc: 0.1632
Epoch 51/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3567 - acc: 0.2332 - val_loss: 0.5858 - val_acc: 0.1624
Epoch 52/100
8000/8000 [==============================] - 214s 27ms/step - loss: 0.3903 - acc: 0.2270 - val_loss: 0.5843 - val_acc: 0.1634
Epoch 53/100
8000/8000 [==============================] - 208s 26ms/step - loss: 0.3545 - acc: 0.2341 - val_loss: 0.5901 - val_acc: 0.1623
Epoch 54/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.3482 - acc: 0.2361 - val_loss: 0.5913 - val_acc: 0.1622
Epoch 55/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3748 - acc: 0.2295 - val_loss: 0.6049 - val_acc: 0.1599
Epoch 56/100
8000/8000 [==============================] - 243s 30ms/step - loss: 0.3662 - acc: 0.2316 - val_loss: 0.5963 - val_acc: 0.1619
Epoch 57/100
8000/8000 [==============================] - 224s 28ms/step - loss: 0.3423 - acc: 0.2380 - val_loss: 0.5986 - val_acc: 0.1620
Epoch 58/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3385 - acc: 0.2391 - val_loss: 0.6029 - val_acc: 0.1621
Epoch 59/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3355 - acc: 0.2402 - val_loss: 0.6075 - val_acc: 0.1617
Epoch 60/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3323 - acc: 0.2410 - val_loss: 0.6124 - val_acc: 0.1608
Epoch 61/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3296 - acc: 0.2418 - val_loss: 0.6131 - val_acc: 0.1616
Epoch 62/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3268 - acc: 0.2427 - val_loss: 0.6171 - val_acc: 0.1612
Epoch 63/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3243 - acc: 0.2435 - val_loss: 0.6229 - val_acc: 0.1597
Epoch 64/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3222 - acc: 0.2442 - val_loss: 0.6245 - val_acc: 0.1606
Epoch 65/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3195 - acc: 0.2448 - val_loss: 0.6269 - val_acc: 0.1601
Epoch 66/100
8000/8000 [==============================] - 214s 27ms/step - loss: 0.3173 - acc: 0.2457 - val_loss: 0.6319 - val_acc: 0.1601
Epoch 67/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.3152 - acc: 0.2462 - val_loss: 0.6350 - val_acc: 0.1595
Epoch 68/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3130 - acc: 0.2469 - val_loss: 0.6387 - val_acc: 0.1601
Epoch 69/100
8000/8000 [==============================] - 214s 27ms/step - loss: 0.3107 - acc: 0.2478 - val_loss: 0.6444 - val_acc: 0.1597
Epoch 70/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.3085 - acc: 0.2484 - val_loss: 0.6452 - val_acc: 0.1599
Epoch 71/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.3076 - acc: 0.2485 - val_loss: 0.6487 - val_acc: 0.1596
Epoch 72/100
8000/8000 [==============================] - 228s 28ms/step - loss: 0.3049 - acc: 0.2493 - val_loss: 0.6526 - val_acc: 0.1592
Epoch 73/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.3037 - acc: 0.2499 - val_loss: 0.6538 - val_acc: 0.1593
Epoch 74/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.3016 - acc: 0.2503 - val_loss: 0.6566 - val_acc: 0.1591
Epoch 75/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.3000 - acc: 0.2506 - val_loss: 0.6581 - val_acc: 0.1585
Epoch 76/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2991 - acc: 0.2513 - val_loss: 0.6610 - val_acc: 0.1596
Epoch 77/100
8000/8000 [==============================] - 213s 27ms/step - loss: 0.2966 - acc: 0.2519 - val_loss: 0.6668 - val_acc: 0.1584
Epoch 78/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2948 - acc: 0.2524 - val_loss: 0.6680 - val_acc: 0.1586
Epoch 79/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.2934 - acc: 0.2529 - val_loss: 0.6721 - val_acc: 0.1579
Epoch 80/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2921 - acc: 0.2531 - val_loss: 0.6754 - val_acc: 0.1584
Epoch 81/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2907 - acc: 0.2535 - val_loss: 0.6775 - val_acc: 0.1585
Epoch 82/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2892 - acc: 0.2538 - val_loss: 0.6828 - val_acc: 0.1580
Epoch 83/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.2880 - acc: 0.2543 - val_loss: 0.6824 - val_acc: 0.1577
Epoch 84/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2866 - acc: 0.2548 - val_loss: 0.6864 - val_acc: 0.1577
Epoch 85/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2848 - acc: 0.2557 - val_loss: 0.6883 - val_acc: 0.1578
Epoch 86/100
8000/8000 [==============================] - 233s 29ms/step - loss: 0.2835 - acc: 0.2557 - val_loss: 0.6874 - val_acc: 0.1573
Epoch 87/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2826 - acc: 0.2561 - val_loss: 0.6908 - val_acc: 0.1581
Epoch 88/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.2808 - acc: 0.2566 - val_loss: 0.6974 - val_acc: 0.1574
Epoch 89/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2799 - acc: 0.2567 - val_loss: 0.6991 - val_acc: 0.1575
Epoch 90/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2789 - acc: 0.2570 - val_loss: 0.7016 - val_acc: 0.1574
Epoch 91/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2781 - acc: 0.2573 - val_loss: 0.7036 - val_acc: 0.1569
Epoch 92/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2760 - acc: 0.2581 - val_loss: 0.7062 - val_acc: 0.1571
Epoch 93/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2752 - acc: 0.2583 - val_loss: 0.7118 - val_acc: 0.1568
Epoch 94/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2744 - acc: 0.2585 - val_loss: 0.7104 - val_acc: 0.1575
Epoch 95/100
8000/8000 [==============================] - 212s 26ms/step - loss: 0.2732 - acc: 0.2588 - val_loss: 0.7117 - val_acc: 0.1569
Epoch 96/100
8000/8000 [==============================] - 209s 26ms/step - loss: 0.2718 - acc: 0.2593 - val_loss: 0.7139 - val_acc: 0.1576
Epoch 97/100
8000/8000 [==============================] - 210s 26ms/step - loss: 0.2707 - acc: 0.2593 - val_loss: 0.7166 - val_acc: 0.1571
Epoch 98/100
8000/8000 [==============================] - 213s 27ms/step - loss: 0.2699 - acc: 0.2597 - val_loss: 0.7189 - val_acc: 0.1569
Epoch 99/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2686 - acc: 0.2601 - val_loss: 0.7236 - val_acc: 0.1571
Epoch 100/100
8000/8000 [==============================] - 211s 26ms/step - loss: 0.2686 - acc: 0.2600 - val_loss: 0.7234 - val_acc: 0.1561
Out[8]:
<keras.callbacks.History at 0xb61cdd0b8>
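
The log above shows val_loss bottoming out around 0.514 near epoch 24 and then climbing while the training loss keeps falling, so the network starts overfitting well before epoch 100. One option (a sketch using the standard Keras callbacks; the checkpoint filename is illustrative) is to stop on the validation loss and keep only the best weights:

>>> from keras.callbacks import EarlyStopping, ModelCheckpoint
>>> callbacks = [
...     EarlyStopping(monitor='val_loss', patience=5),
...     ModelCheckpoint('seq2seq_best_weights.h5', monitor='val_loss',
...                     save_best_only=True, save_weights_only=True)]
>>> model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
...           batch_size=batch_size, epochs=epochs,
...           validation_split=0.2, callbacks=callbacks)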

In [9]:
model_path = os.path.join(DATA_PATH, 'ch10_train_seq2seq_keras')
model.save(model_path + '_model.h5')


/Users/hobsonlane/anaconda3/envs/nlpiaenv/lib/python3.6/site-packages/keras/engine/network.py:877: UserWarning: Layer lstm_2 was passed non-serializable keyword arguments: {'initial_state': [<tf.Tensor 'lstm_1/while/Exit_2:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lstm_1/while/Exit_3:0' shape=(?, 256) dtype=float32>]}. They will not be included in the serialized model (and thus will be missing at deserialization time).
  '. They will not be included '
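
The warning refers to the initial_state keyword used to wire the decoder LSTM to the encoder states: Keras cannot serialize that connection into the saved architecture, which is why the next cell also saves the weights separately.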

In [ ]:
model.save_weights(model_path + '_weights.h5')

Model inference (activation) without retraining

Everything below can be run without rerunning the training cell, as long as the model-construction cell (In [32]) has been executed so that encoder_inputs, encoder_states, decoder_lstm, and decoder_dense are defined.
TODO: put these cells in a separate notebook named ch10inference...


In [33]:
from keras.models import load_model

model_path = os.path.join(DATA_PATH, 'ch10_train_seq2seq_keras')
model = load_model(model_path + '_model.h5')

In [34]:
model.load_weights(model_path + '_weights.h5')

In [35]:
encoder_model = Model(encoder_inputs, encoder_states)
thought_input = [
    Input(shape=(num_neurons,)), Input(shape=(num_neurons,))]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=thought_input)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = Model(
    inputs=[decoder_inputs] + thought_input,
    outputs=[decoder_outputs] + decoder_states)


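The generation loop below relies on a few names that are not defined elsewhere in this notebook. Under the conventions of the preprocessing cell above (tab as the start-of-sequence character, newline as the stop character) they can be filled in as follows (a sketch; the names simply match the listing below):

>>> output_vocab_size = num_decoder_tokens
>>> start_token = '\t'   # start-of-sequence character (prepended during preprocessing)
>>> stop_token = '\n'    # end-of-sequence character (appended during preprocessing)
>>> reverse_target_char_index = dict(
...     (i, char) for char, i in target_token_index.items())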

In [ ]:
>>> def decode_sequence(input_seq):
...     thought = encoder_model.predict(input_seq)  # <1>

...     target_seq = np.zeros((1, 1, output_vocab_size))  # <2>
...     target_seq[0, 0, target_token_index[start_token]] = 1.  # <3>
...     stop_condition = False
...     generated_sequence = ''

...     while not stop_condition:
...         output_tokens, h, c = decoder_model.predict(
...             [target_seq] + thought) # <4>

...         generated_token_idx = np.argmax(output_tokens[0, -1, :])
...         generated_char = reverse_target_char_index[generated_token_idx]
...         generated_sequence += generated_char
...         if (generated_char == stop_token or
...                 len(generated_sequence) > max_decoder_seq_length
...                 ):  # <5>
...             stop_condition = True

...         target_seq = np.zeros((1, 1, output_vocab_size))  # <6>
...         target_seq[0, 0, generated_token_idx] = 1.
...         thought = [h, c]  # <7>

...     return generated_sequence
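
To exercise the loop, feed it a single encoded sequence shaped (1, max_encoder_seq_length, num_encoder_tokens). A smoke test on the first training sample (the reply quality depends on which weights are actually loaded into the layers behind encoder_model and decoder_model):

>>> reply = decode_sequence(encoder_input_data[0:1])
>>> print('input:', input_texts[0])
>>> print('reply:', reply)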