Learning addition rules

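This example comes directly from the Keras examples (`examples/addition_rnn.py`): a sequence-to-sequence recurrent network learns to add two integers of up to `DIGITS` digits, given the question as a string such as `"535+61"`.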


In [12]:
%pylab inline

from keras.models import Sequential
from keras.engine.training import slice_X
from keras.layers.core import Activation, Dense, RepeatVector
from keras.layers import recurrent
from keras.layers.wrappers import TimeDistributed
import numpy as np
from six.moves import range


Populating the interactive namespace from numpy and matplotlib

In [2]:
class CharacterTable(object):
    '''
    Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    '''
    def __init__(self, chars, maxlen):
        '''
        chars  -- iterable of characters forming the alphabet. Duplicates are
                  dropped and the characters are sorted, so each character's
                  index is its position in sorted order.
        maxlen -- default number of rows (time steps) produced by `encode`.
        '''
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, C, maxlen=None):
        '''
        One-hot encode the string `C` into a float array of shape
        (maxlen, len(self.chars)).

        Row i holds a single 1 at the index of C[i]. Rows past len(C) stay
        all-zero.

        maxlen -- number of rows; defaults to self.maxlen when None.

        Raises KeyError if `C` contains a character outside the alphabet.
        Raises IndexError if len(C) > maxlen.
        '''
        # Compare with None (not truthiness) so an explicit maxlen=0 is honoured.
        if maxlen is None:
            maxlen = self.maxlen
        X = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(C):
            X[i, self.char_indices[c]] = 1
        return X

    def decode(self, X, calc_argmax=True):
        '''
        Turn a sequence back into a string.

        calc_argmax -- if True, X is a (length, n_chars) array of one-hot rows
                       or probabilities. Each row is reduced with argmax. An
                       all-zero row maps to index 0, i.e. the first character
                       in sorted order.
                       If False, X is already a sequence of integer indices.
        '''
        if calc_argmax:
            X = X.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in X)

In [15]:
class colors:
    """ANSI SGR escape sequences used to colour the per-sample verdict printed during training."""
    ok = '\x1b[92m'     # bright green foreground
    fail = '\x1b[91m'   # bright red foreground
    close = '\x1b[0m'   # reset all attributes

In [4]:
# Parameters for the model and dataset
# Seed NumPy's global RNG so the generated questions, the shuffle and the
# validation samples shown during training are reproducible on re-run.
SEED = 1337
np.random.seed(SEED)

TRAINING_SIZE = 50000
DIGITS = 3
# Reverse each padded query string before encoding (undone when printing).
INVERT = True
# Recurrent layer class; try replacing LSTM with recurrent.GRU or recurrent.SimpleRNN
RNN = recurrent.LSTM
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1
# Longest possible query: DIGITS digits + '+' + DIGITS digits, e.g. '345+678'
MAXLEN = DIGITS + 1 + DIGITS

chars = '0123456789+ '
ctable = CharacterTable(chars, MAXLEN)

In [5]:
def random_number():
    """Return a random non-negative int built from 1..DIGITS random decimal digits.

    Leading zeros are allowed (e.g. '007' -> 7), so every value in
    [0, 10**DIGITS) can be produced.
    """
    n_digits = np.random.randint(1, DIGITS + 1)
    return int(''.join(np.random.choice(list('0123456789')) for _ in range(n_digits)))


# Operands lie in [0, 10**DIGITS), and X+Y / Y+X count as one question.
# So only N*(N+1)/2 distinct questions exist; asking for more would make the
# rejection loop below spin forever.
n_values = 10 ** DIGITS
max_questions = n_values * (n_values + 1) // 2
if TRAINING_SIZE > max_questions:
    raise ValueError(
        'TRAINING_SIZE={} exceeds the {} distinct questions possible with DIGITS={}'
        .format(TRAINING_SIZE, max_questions, DIGITS))

questions = []
expected = []
seen = set()
print('Generating data...')
while len(questions) < TRAINING_SIZE:
    a, b = random_number(), random_number()
    # Skip any addition questions we've already seen.
    # X+Y and Y+X count as the same question, hence the order-independent (sorted) key.
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the data with spaces such that it is always MAXLEN
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # Answers can be of maximum size DIGITS + 1
    ans += ' ' * (DIGITS + 1 - len(ans))
    if INVERT:
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))


Generating data...
Total addition questions: 50000

In [6]:
print('Vectorization...')
# Use the builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and
# removed in 1.24, where it raises AttributeError.
X = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    X[i] = ctable.encode(sentence, maxlen=MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, maxlen=DIGITS + 1)

# Shuffle (X, y) in unison (one permutation applied to both) so the
# train/validation split taken next is a random sample, not tied to
# generation order.
indices = np.arange(len(y))
np.random.shuffle(indices)
X = X[indices]
y = y[indices]


Vectorization...

In [8]:
# Explicitly set apart 10% for validation data that we never train over.
# int() rounds the training size down, so validation gets at least 10%.
split_at = int(len(X) - len(X) / 10)
# Plain NumPy slicing (same as for y). keras.engine.training.slice_X is a
# private Keras 1 helper that no longer exists in Keras 2.
(X_train, X_val) = (X[:split_at], X[split_at:])
(y_train, y_val) = (y[:split_at], y[split_at:])

print(X_train.shape)
print(y_train.shape)


(45000, 7, 12)
(45000, 4, 12)

In [13]:
n_chars = len(chars)
answer_len = DIGITS + 1  # answers are padded to DIGITS + 1 characters

model = Sequential()
# Encoder: read the whole (MAXLEN, n_chars) query and emit only its final output.
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, n_chars)))
# Hand that single vector to the decoder once per answer character.
model.add(RepeatVector(answer_len))
# Decoder: LAYERS stacked recurrent layers, each returning the full sequence.
for _layer in range(LAYERS):
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# Per-time-step softmax classifier over the character alphabet.
model.add(TimeDistributed(Dense(n_chars)))
model.add(Activation('softmax'))

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [17]:
LAST_ITERATION = 199      # iterations 1..LAST_ITERATION, one epoch each
N_EXAMPLES_TO_SHOW = 10   # random validation examples printed per iteration

for iteration in range(1, LAST_ITERATION + 1):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    # Keras 1 API: `nb_epoch` (renamed `epochs` in Keras 2).
    model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
              validation_data=(X_val, y_val))
    ###
    # Select samples from the validation set at random so we can visualize errors.
    # Predict them in a single batch instead of one model call per sample.
    sample_ids = np.random.randint(0, len(X_val), size=N_EXAMPLES_TO_SHOW)
    preds = model.predict_classes(X_val[sample_ids], verbose=0)
    for rowX, rowy, pred in zip(X_val[sample_ids], y_val[sample_ids], preds):
        q = ctable.decode(rowX)
        correct = ctable.decode(rowy)
        guess = ctable.decode(pred, calc_argmax=False)
        # Undo the query reversal so the question prints left-to-right.
        print('Q', q[::-1] if INVERT else q)
        print('T', correct)
        if correct == guess:
            mark = colors.ok + '☑' + colors.close
        else:
            mark = colors.fail + '☒' + colors.close
        print(mark, guess)
        print('---')


--------------------------------------------------
Iteration 1
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 46s - loss: 1.6038 - acc: 0.4012 - val_loss: 1.5121 - val_acc: 0.4376
Q 735+8  
T 743 
 839 
---
Q 789+70 
T 859 
 808 
---
Q 46+373 
T 419 
 499 
---
Q 872+902
T 1774
 1710
---
Q 191+9  
T 200 
 110 
---
Q 38+900 
T 938 
 900 
---
Q 621+9  
T 630 
 129 
---
Q 616+73 
T 689 
 789 
---
Q 834+74 
T 908 
 801 
---
Q 783+28 
T 811 
 801 
---

--------------------------------------------------
Iteration 2
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 43s - loss: 1.4134 - acc: 0.4756 - val_loss: 1.3410 - val_acc: 0.4965
Q 553+92 
T 645 
 666 
---
Q 449+16 
T 465 
 454 
---
Q 878+96 
T 974 
 886 
---
Q 263+8  
T 271 
 286 
---
Q 44+87  
T 131 
 141 
---
Q 85+853 
T 938 
 949 
---
Q 36+735 
T 771 
 744 
---
Q 381+54 
T 435 
 499 
---
Q 3+951  
T 954 
 944 
---
Q 875+500
T 1375
 1268
---

--------------------------------------------------
Iteration 3
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 40s - loss: 1.2506 - acc: 0.5389 - val_loss: 1.1851 - val_acc: 0.5681
Q 525+622
T 1147
 1164
---
Q 64+807 
T 871 
 860 
---
Q 60+892 
T 952 
 900 
---
Q 985+2  
T 987 
 990 
---
Q 191+840
T 1031
 1009
---
Q 295+67 
T 362 
 355 
---
Q 8+622  
T 630 
 634 
---
Q 50+63  
T 113 
 101 
---
Q 711+48 
T 759 
 764 
---
Q 660+27 
T 687 
 673 
---

--------------------------------------------------
Iteration 4
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 43s - loss: 1.1156 - acc: 0.5934 - val_loss: 1.0566 - val_acc: 0.6168
Q 92+405 
T 497 
 402 
---
Q 7+680  
T 687 
 685 
---
Q 815+80 
T 895 
 896 
---
Q 517+0  
T 517 
 510 
---
Q 3+371  
T 374 
 374 
---
Q 920+68 
T 988 
 980 
---
Q 176+726
T 902 
 901 
---
Q 880+560
T 1440
 1457
---
Q 153+710
T 863 
 805 
---
Q 589+20 
T 609 
 601 
---

--------------------------------------------------
Iteration 5
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 40s - loss: 0.9861 - acc: 0.6427 - val_loss: 0.9210 - val_acc: 0.6675
Q 711+351
T 1062
 1043
---
Q 378+9  
T 387 
 396 
---
Q 613+1  
T 614 
 615 
---
Q 215+540
T 755 
 754 
---
Q 2+233  
T 235 
 235 
---
Q 910+78 
T 988 
 988 
---
Q 4+198  
T 202 
 200 
---
Q 435+73 
T 508 
 508 
---
Q 898+4  
T 902 
 901 
---
Q 951+709
T 1660
 1602
---

--------------------------------------------------
Iteration 6
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 40s - loss: 0.7951 - acc: 0.7114 - val_loss: 0.6767 - val_acc: 0.7600
Q 155+0  
T 155 
 155 
---
Q 993+435
T 1428
 1438
---
Q 330+214
T 544 
 544 
---
Q 513+472
T 985 
 996 
---
Q 50+938 
T 988 
 999 
---
Q 71+384 
T 455 
 455 
---
Q 603+6  
T 609 
 619 
---
Q 136+22 
T 158 
 158 
---
Q 962+396
T 1358
 1358
---
Q 786+573
T 1359
 1358
---

--------------------------------------------------
Iteration 7
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 40s - loss: 0.5704 - acc: 0.8093 - val_loss: 0.4900 - val_acc: 0.8495
Q 18+655 
T 673 
 673 
---
Q 204+657
T 861 
 871 
---
Q 5+340  
T 345 
 345 
---
Q 588+49 
T 637 
 637 
---
Q 307+872
T 1179
 1189
---
Q 657+458
T 1115
 1115
---
Q 519+69 
T 588 
 587 
---
Q 289+990
T 1279
 1299
---
Q 227+3  
T 230 
 230 
---
Q 603+675
T 1278
 1278
---

--------------------------------------------------
Iteration 8
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 41s - loss: 0.4116 - acc: 0.8828 - val_loss: 0.3778 - val_acc: 0.8882
Q 0+263  
T 263 
 264 
---
Q 260+17 
T 277 
 277 
---
Q 739+513
T 1252
 1251
---
Q 5+543  
T 548 
 549 
---
Q 484+397
T 881 
 860 
---
Q 67+845 
T 912 
 912 
---
Q 71+47  
T 118 
 118 
---
Q 470+580
T 1050
 1040
---
Q 930+2  
T 932 
 932 
---
Q 553+71 
T 624 
 625 
---

--------------------------------------------------
Iteration 9
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 40s - loss: 0.3035 - acc: 0.9260 - val_loss: 0.2654 - val_acc: 0.9378
Q 91+102 
T 193 
 193 
---
Q 307+451
T 758 
 758 
---
Q 457+23 
T 480 
 480 
---
Q 622+2  
T 624 
 624 
---
Q 87+584 
T 671 
 671 
---
Q 32+165 
T 197 
 297 
---
Q 835+7  
T 842 
 842 
---
Q 295+86 
T 381 
 381 
---
Q 117+32 
T 149 
 149 
---
Q 5+840  
T 845 
 845 
---

--------------------------------------------------
Iteration 10
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 38s - loss: 0.2308 - acc: 0.9488 - val_loss: 0.2142 - val_acc: 0.9471
Q 8+515  
T 523 
 523 
---
Q 95+823 
T 918 
 918 
---
Q 755+60 
T 815 
 815 
---
Q 151+475
T 626 
 626 
---
Q 15+752 
T 767 
 767 
---
Q 20+622 
T 642 
 642 
---
Q 5+340  
T 345 
 345 
---
Q 652+121
T 773 
 763 
---
Q 478+839
T 1317
 1206
---
Q 70+508 
T 578 
 578 
---

--------------------------------------------------
Iteration 11
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 39s - loss: 0.1761 - acc: 0.9648 - val_loss: 0.1703 - val_acc: 0.9625
Q 494+815
T 1309
 1309
---
Q 506+75 
T 581 
 581 
---
Q 198+258
T 456 
 446 
---
Q 30+46  
T 76  
 76  
---
Q 976+2  
T 978 
 978 
---
Q 758+76 
T 834 
 834 
---
Q 7+376  
T 383 
 383 
---
Q 722+70 
T 792 
 792 
---
Q 277+56 
T 333 
 333 
---
Q 263+8  
T 271 
 271 
---

--------------------------------------------------
Iteration 12
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 39s - loss: 0.1383 - acc: 0.9739 - val_loss: 0.1368 - val_acc: 0.9712
Q 934+103
T 1037
 1047
---
Q 58+942 
T 1000
 1000
---
Q 333+75 
T 408 
 408 
---
Q 3+3    
T 6   
 4   
---
Q 90+37  
T 127 
 127 
---
Q 6+117  
T 123 
 123 
---
Q 408+813
T 1221
 1221
---
Q 467+6  
T 473 
 473 
---
Q 444+825
T 1269
 1269
---
Q 600+764
T 1364
 1364
---

--------------------------------------------------
Iteration 13
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 39s - loss: 0.1232 - acc: 0.9743 - val_loss: 0.1151 - val_acc: 0.9758
Q 10+40  
T 50  
 50  
---
Q 489+66 
T 555 
 555 
---
Q 546+67 
T 613 
 613 
---
Q 41+111 
T 152 
 152 
---
Q 95+523 
T 618 
 618 
---
Q 121+55 
T 176 
 176 
---
Q 282+10 
T 292 
 292 
---
Q 47+499 
T 546 
 545 
---
Q 247+8  
T 255 
 255 
---
Q 885+83 
T 968 
 968 
---

--------------------------------------------------
Iteration 14
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 39s - loss: 0.0863 - acc: 0.9862 - val_loss: 0.0857 - val_acc: 0.9839
Q 165+521
T 686 
 686 
---
Q 23+993 
T 1016
 1016
---
Q 57+45  
T 102 
 102 
---
Q 577+467
T 1044
 1044
---
Q 37+903 
T 940 
 940 
---
Q 711+48 
T 759 
 759 
---
Q 516+296
T 812 
 812 
---
Q 8+567  
T 575 
 575 
---
Q 73+668 
T 741 
 741 
---
Q 129+389
T 518 
 518 
---

--------------------------------------------------
Iteration 15
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
45000/45000 [==============================] - 45s - loss: 0.0764 - acc: 0.9870 - val_loss: 0.0810 - val_acc: 0.9832
Q 569+841
T 1410
 1410
---
Q 381+54 
T 435 
 435 
---
Q 502+87 
T 589 
 589 
---
Q 749+29 
T 778 
 778 
---
Q 6+138  
T 144 
 144 
---
Q 29+999 
T 1028
 1028
---
Q 341+899
T 1240
 1240
---
Q 4+990  
T 994 
 993 
---
Q 0+476  
T 476 
 476 
---
Q 949+89 
T 1038
 1038
---

--------------------------------------------------
Iteration 16
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
13056/45000 [=======>......................] - ETA: 33s - loss: 0.0677 - acc: 0.9878
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-17-49d6041109fa> in <module>()
      6     print('Iteration', iteration)
      7     model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
----> 8               validation_data=(X_val, y_val))
      9     ###
     10     # Select 10 samples from the validation set at random so we can visualize errors

/Users/taylor/anaconda3/lib/python3.5/site-packages/keras/models.py in fit(self, x, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, **kwargs)
    395                               shuffle=shuffle,
    396                               class_weight=class_weight,
--> 397                               sample_weight=sample_weight)
    398 
    399     def evaluate(self, x, y, batch_size=32, verbose=1,

/Users/taylor/anaconda3/lib/python3.5/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight)
   1009                               verbose=verbose, callbacks=callbacks,
   1010                               val_f=val_f, val_ins=val_ins, shuffle=shuffle,
-> 1011                               callback_metrics=callback_metrics)
   1012 
   1013     def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):

/Users/taylor/anaconda3/lib/python3.5/site-packages/keras/engine/training.py in _fit_loop(self, f, ins, out_labels, batch_size, nb_epoch, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics)
    747                 batch_logs['size'] = len(batch_ids)
    748                 callbacks.on_batch_begin(batch_index, batch_logs)
--> 749                 outs = f(ins_batch)
    750                 if type(outs) != list:
    751                     outs = [outs]

/Users/taylor/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
    579         feed_dict = dict(zip(names, inputs))
    580         session = get_session()
--> 581         updated = session.run(self.outputs + self.updates, feed_dict=feed_dict)
    582         return updated[:len(self.outputs)]
    583 

/Users/taylor/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict)
    313         `Tensor` that doesn't exist.
    314     """
--> 315     return self._run(None, fetches, feed_dict)
    316 
    317   def partial_run(self, handle, fetches, feed_dict=None):

/Users/taylor/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict)
    509     # Run request and get response.
    510     results = self._do_run(handle, target_list, unique_fetches,
--> 511                            feed_dict_string)
    512 
    513     # User may have fetched the same tensor multiple times, but we

/Users/taylor/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict)
    562     if handle is None:
    563       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
--> 564                            target_list)
    565     else:
    566       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/Users/taylor/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
    569   def _do_call(self, fn, *args):
    570     try:
--> 571       return fn(*args)
    572     except tf_session.StatusNotOK as e:
    573       e_type, e_value, e_traceback = sys.exc_info()

/Users/taylor/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list)
    553       # Ensure any changes to the graph are reflected in the runtime.
    554       self._extend_graph()
--> 555       return tf_session.TF_Run(session, feed_dict, fetch_list, target_list)
    556 
    557     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: