In [1]:
import random
import numpy as np
import keras
import tensorflow as tf
from keras import models
from keras import layers
import matplotlib.pyplot as plt
%matplotlib inline


Using TensorFlow backend.

In [2]:
# Record library versions for reproducibility (this run: TF 1.3.0, Keras 1.2.2).
print('tf.__version__='+tf.__version__)
print('keras.__version__='+keras.__version__)


tf.__version__=1.3.0
keras.__version__=1.2.2

In [3]:
# This is the secret algorithm we would like to break
def calculate_y(x):
    """Deterministic target label for a 4-feature input.

    Only (x[0], x[1], x[2]) matter; x[3] is a decoy feature the network
    should learn to ignore.  Returns 1, 2, or 3 for the three special
    patterns below, and 0 for everything else.
    """
    special_patterns = {
        (0, 0, 1): 1,
        (0, 0, 2): 2,
        (0, 0, 3): 3,
    }
    key = (x[0], x[1], x[2])
    return special_patterns.get(key, 0)

In [4]:
def create_x_y(nb_samples):
    """Generate `nb_samples` random inputs and their secret-algorithm labels.

    Returns:
        X: int array of shape (nb_samples, 4) — features x0..x2 in 0..2,
           x3 in 0..9 (int(random.uniform(0, k)) yields 0..k-1; k itself
           has probability ~0).
        Y: int array of shape (nb_samples, 1) — labels from calculate_y.
           (The old code returned shape (1,) when nb_samples == 1 and
           (n, 1) otherwise; the shape is now consistent.)
    """
    xs = []
    ys = []
    for _ in range(nb_samples):
        x = np.array([
            int(random.uniform(0, 3)),
            int(random.uniform(0, 3)),
            int(random.uniform(0, 3)),
            int(random.uniform(0, 10)),
        ])
        xs.append(x)
        ys.append(calculate_y(x))

    # Stack once at the end: O(n) instead of the old quadratic repeated vstack.
    X = np.stack(xs)
    Y = np.array(ys).reshape(-1, 1)
    return X, Y

In [5]:
def to_one_hot(labels, dimension=4):
    """One-hot encode integer labels into a float array of shape (n, dimension).

    Accepts either a flat sequence of ints or an (n, 1) column array (the
    shape produced by create_x_y) — both are flattened first.  Vectorized
    with fancy indexing instead of the old Python loop.
    """
    flat = np.asarray(labels).reshape(-1)
    results = np.zeros((len(flat), dimension))
    results[np.arange(len(flat)), flat.astype(int)] = 1.
    return results

In [6]:
#print(to_one_hot([0,1,2,3]))

In [7]:
#print(create_x_y(100))

In [8]:
#train_data, train_targets = create_x_y(int(nb_samples*.8))
#print(train_data.shape)
#print(train_targets.shape)

# Sanity-check the data generator on a small held-out set of 10 samples.
test_data, test_targets = create_x_y(10)
print(test_targets)        # label column vector, shape (10, 1)
print(test_data.shape)     # (10, 4)
print(test_data.shape[1])  # number of features: 4
print(test_targets.shape)  # (10, 1)


[[0]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [0]
 [0]
 [0]]
(10, 4)
4
(10, 1)

In [28]:
nb_features = 4
# MLP classifier: two 64-unit ReLU hidden layers feeding a 4-way softmax
# (one output per possible y value 0..3).
# Note: Sequential.add returns None, so the previous `dense1 = ...` /
# `dense2 = ...` assignments were dead and have been removed.
model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(nb_features,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(4, activation='softmax'))

# Compile model: cross-entropy on one-hot targets, RMSprop optimizer.
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

In [10]:
model.get_weights()


Out[10]:
[array([[ 0.18462187,  0.27944225, -0.15826008, -0.0129047 , -0.271171  ,
          0.15089294,  0.13135508, -0.29143903,  0.10521197, -0.13418977,
         -0.02750462, -0.07742067,  0.26466602,  0.18798637,  0.161407  ,
          0.04837978,  0.28884608,  0.26538807,  0.08402535,  0.12140802,
         -0.2733678 , -0.24578136,  0.27479887,  0.07557422,  0.19712588,
          0.1411739 , -0.21707249,  0.08724728,  0.27339125, -0.03663883,
          0.03280884,  0.06031615, -0.04925507, -0.11493094,  0.24619424,
         -0.17485434,  0.27360344,  0.2771036 ,  0.27220494, -0.20620438,
         -0.14461403, -0.29408708,  0.05026668,  0.11174217, -0.03475875,
         -0.18606074,  0.14765677,  0.22381014,  0.025415  , -0.17301178,
          0.13259381,  0.18251446, -0.21593383,  0.22360033, -0.01319364,
          0.03046182,  0.18188325, -0.17439881,  0.06840646, -0.24104649,
          0.12061509,  0.07001069,  0.22086829, -0.01755664],
        [-0.10359395,  0.10329664, -0.14735012,  0.23399514, -0.0785497 ,
          0.01442954,  0.1121926 , -0.14810656,  0.28846622,  0.07749566,
         -0.01801082, -0.15760492, -0.18429014,  0.22654098,  0.03943652,
          0.00047755, -0.22783819, -0.29348645,  0.11395934, -0.13038123,
          0.24306262,  0.25091296,  0.03681305, -0.25427762, -0.16379657,
          0.22506577,  0.02987799,  0.0425683 , -0.05364574, -0.12406734,
          0.01796675, -0.21475822,  0.01029217, -0.20941725,  0.08697945,
          0.16934708, -0.13708718,  0.1808185 ,  0.23067188,  0.2955386 ,
          0.24905103,  0.02639323,  0.12610301, -0.04993729, -0.2068284 ,
          0.25083578,  0.20313305,  0.13633996,  0.20327008,  0.284382  ,
          0.27142674,  0.08947206,  0.24926531, -0.25214994,  0.14998919,
          0.22282088, -0.2758451 ,  0.22162265,  0.02775428,  0.09172174,
         -0.29445946,  0.2639233 ,  0.15323612,  0.04847291],
        [-0.04289529, -0.27673498,  0.02077729, -0.09915617, -0.16933236,
         -0.19965184, -0.11645083, -0.11473016, -0.28788757, -0.24363004,
          0.17631105, -0.20325938, -0.2766973 ,  0.05015662,  0.06138003,
         -0.24359824, -0.00101924,  0.11004198, -0.08563305, -0.16072458,
          0.10561121, -0.09253009,  0.13635418, -0.23161478,  0.180635  ,
         -0.08374518, -0.2590865 , -0.17647684, -0.2174516 ,  0.2071777 ,
         -0.0274615 ,  0.06043732,  0.16371733,  0.06038606, -0.18433022,
         -0.00249162,  0.02509898, -0.0874528 , -0.03876543,  0.28091747,
         -0.22274277,  0.03312597,  0.14398578, -0.17645057,  0.1380879 ,
         -0.02382463,  0.17533854, -0.15475771,  0.09484905, -0.25352088,
          0.03244922, -0.2918884 , -0.03276274, -0.05260652, -0.18317953,
          0.27295035, -0.13308042, -0.15810859, -0.09791659, -0.27140507,
         -0.04491168, -0.08013049,  0.08014706, -0.1832318 ],
        [ 0.13632649, -0.26633734, -0.26140204, -0.0933492 , -0.02161041,
         -0.24995485,  0.17054886,  0.07753018,  0.13043371,  0.24298614,
         -0.1149243 ,  0.06318602,  0.06740731, -0.05265312,  0.25569457,
         -0.15120058, -0.00073045, -0.1711749 , -0.29052272,  0.19164199,
          0.07535595, -0.26624882,  0.18855456, -0.07160324,  0.04110798,
          0.00161895,  0.05235171, -0.02081433,  0.10931981, -0.01777697,
         -0.2118258 ,  0.13630214,  0.0564169 , -0.09700604, -0.06366754,
         -0.05181347,  0.28022474, -0.27129516,  0.00114587,  0.09404922,
         -0.14023241,  0.04570672,  0.20516217, -0.2593466 , -0.08432068,
          0.06900665,  0.22093934,  0.00812593,  0.00454777, -0.02952218,
         -0.28502598, -0.28390947,  0.26458693, -0.10642111, -0.00952059,
         -0.01736507, -0.27408612,  0.00141117, -0.06338234,  0.13449246,
          0.14566833,  0.18429774,  0.21452904,  0.27480465]],
       dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([[-0.1881277 ,  0.11279617,  0.17774932, ..., -0.20515102,
         -0.00652274,  0.06021185],
        [-0.09763677, -0.20621969,  0.16151018, ...,  0.15735589,
         -0.05543292, -0.20065117],
        [-0.21195745,  0.19385608,  0.2089635 , ..., -0.16124971,
         -0.17847109, -0.01337989],
        ...,
        [-0.1823911 ,  0.13630824, -0.12906998, ..., -0.19991858,
         -0.13332227, -0.07555848],
        [ 0.05250476,  0.21475805, -0.03784008, ..., -0.10045817,
          0.03795354,  0.1922005 ],
        [-0.17218575, -0.19780323,  0.15942971, ..., -0.08085802,
         -0.0522698 , -0.03805198]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([[ 0.19388163, -0.27282813,  0.18539801,  0.03080517],
        [ 0.18311587,  0.24469578,  0.17997092,  0.03269368],
        [ 0.18598872, -0.10569109,  0.1444554 ,  0.11622888],
        [ 0.21840072,  0.15202445,  0.14116555, -0.2476511 ],
        [ 0.26639974, -0.26620582,  0.10555008, -0.14001139],
        [ 0.01611748, -0.2657277 ,  0.23625934, -0.19539805],
        [ 0.2920019 , -0.0464485 , -0.06283689, -0.00258857],
        [ 0.09766164,  0.1534709 , -0.14128283,  0.02123648],
        [-0.13289729, -0.17614865, -0.08047086, -0.1036353 ],
        [ 0.28193778,  0.26135284, -0.04301879,  0.12902841],
        [-0.03914183,  0.00654817, -0.11571515,  0.18020386],
        [-0.23098986,  0.14514554,  0.2041114 , -0.04991589],
        [-0.02591655, -0.08080074,  0.04706493,  0.00217059],
        [-0.26043966, -0.11733107,  0.21074808,  0.10244033],
        [ 0.11344266, -0.1511736 , -0.07233007, -0.24409455],
        [ 0.02296779, -0.16604811, -0.21149111,  0.03440082],
        [ 0.20014268,  0.05903974, -0.11337771,  0.16202033],
        [-0.25060937, -0.05339639,  0.14370432,  0.05106956],
        [-0.02077693,  0.1085701 ,  0.11275235,  0.17128488],
        [-0.24237283, -0.0629468 ,  0.18462649, -0.04627669],
        [-0.12902941, -0.06586058, -0.28310794,  0.21164358],
        [-0.24597506,  0.10543853,  0.03983116,  0.0451991 ],
        [ 0.10400477, -0.25647616, -0.04329026, -0.0524289 ],
        [-0.17039496, -0.19853662,  0.06472814,  0.124506  ],
        [ 0.25828308,  0.02010733,  0.29200238,  0.14651117],
        [ 0.0386939 , -0.21985349,  0.05782807,  0.02719176],
        [-0.24513832, -0.22972563, -0.22479871,  0.00339112],
        [-0.14898176, -0.11322409,  0.25605243, -0.27579036],
        [-0.06608813,  0.15431267, -0.17473903,  0.13022655],
        [-0.20730239, -0.19335388, -0.00677797,  0.0644322 ],
        [-0.21078771,  0.29412878, -0.28098848,  0.01778334],
        [-0.26151145,  0.1898112 ,  0.02097842,  0.16706035],
        [ 0.21748567, -0.2770859 , -0.04349196, -0.11320207],
        [-0.14706925, -0.2760794 , -0.17149863, -0.20236427],
        [ 0.01476482, -0.20841414, -0.26071018, -0.1511343 ],
        [ 0.085897  ,  0.18317622, -0.20146485,  0.02320623],
        [-0.19390926,  0.26469928, -0.16485463, -0.28280306],
        [ 0.05677184, -0.2160235 , -0.28780124,  0.27615178],
        [ 0.17285988, -0.05617703, -0.2566166 ,  0.20342505],
        [-0.18715954,  0.27926052,  0.15893605,  0.11090067],
        [ 0.01141909, -0.11619836,  0.06584641, -0.28348136],
        [ 0.26088738, -0.20434025, -0.00617346,  0.05059189],
        [ 0.03290302,  0.18821001, -0.2770133 , -0.18319052],
        [-0.03349933,  0.20786399,  0.11005732,  0.12267515],
        [ 0.28925687, -0.02888316, -0.16531397,  0.02430913],
        [-0.2013261 ,  0.20014462,  0.19391838,  0.06940219],
        [-0.26816317, -0.20158489,  0.00846493, -0.23864913],
        [-0.10848144, -0.1551514 ,  0.12686306, -0.18464291],
        [-0.1241405 , -0.25531995,  0.07984146,  0.2916832 ],
        [-0.10349472, -0.16504472, -0.18904966, -0.1907214 ],
        [ 0.24123192, -0.08368804, -0.21619907, -0.00365406],
        [ 0.21249223, -0.0868288 ,  0.10485646,  0.24572796],
        [-0.14529236,  0.19011346,  0.0881691 ,  0.05120626],
        [ 0.19380048, -0.20086846,  0.29266477, -0.04797399],
        [-0.04236731,  0.02455261,  0.11414087, -0.28181276],
        [ 0.18192637, -0.11131908,  0.18520212, -0.05494311],
        [-0.13305593,  0.2554304 ,  0.20697927,  0.1687716 ],
        [ 0.0337072 ,  0.08384207, -0.22301862, -0.03012869],
        [ 0.01669398, -0.21113437, -0.16615874,  0.28473717],
        [-0.10157208, -0.02560726,  0.22970974, -0.05118692],
        [-0.06794497, -0.17114714, -0.09204547, -0.26470023],
        [-0.1850101 ,  0.23023558, -0.05054061,  0.10923141],
        [-0.29699972,  0.08620593,  0.22563255,  0.0773128 ],
        [-0.21545339,  0.2664702 ,  0.17890075,  0.13797131]],
       dtype=float32),
 array([0., 0., 0., 0.], dtype=float32)]

In [34]:
# ONLINE, STEP BY STEP in ONLINE MODE !!!
# Online training: generate one fresh sample at a time and fit on it for
# 5 epochs (2000 samples total).  `nb_epoch` is the Keras 1.x parameter
# name (renamed to `epochs` in Keras 2).
# NOTE(review): `history` ends up holding only the History of the LAST
# single-sample fit — the per-sample histories are discarded.
for i in range(2000):
    x, y = create_x_y(1)
    history = model.fit(x, to_one_hot(y), nb_epoch=5, batch_size=1,  verbose=1)

In [12]:
# ALL IN ONE STEP ('OSSIFIED')
# Batch training: generate all 2000 samples up front, then fit for 20 epochs
# with batch_size=1.  `nb_epoch` is the Keras 1.x name (epochs in Keras 2).
x, y = create_x_y(2000)
history = model.fit(x, to_one_hot(y), nb_epoch=20, batch_size=1,  verbose=1)


Epoch 1/20
2000/2000 [==============================] - 2s - loss: 0.1315 - acc: 0.9550     
Epoch 2/20
2000/2000 [==============================] - 2s - loss: 0.0377 - acc: 0.9835     
Epoch 3/20
2000/2000 [==============================] - 2s - loss: 0.0152 - acc: 0.9970     
Epoch 4/20
2000/2000 [==============================] - 2s - loss: 0.0057 - acc: 0.9975     
Epoch 5/20
2000/2000 [==============================] - 2s - loss: 0.0049 - acc: 0.9985     
Epoch 6/20
2000/2000 [==============================] - 2s - loss: 8.5021e-04 - acc: 0.9995     
Epoch 7/20
2000/2000 [==============================] - 2s - loss: 1.6301e-04 - acc: 1.0000     
Epoch 8/20
2000/2000 [==============================] - 2s - loss: 1.3819e-04 - acc: 1.0000     
Epoch 9/20
2000/2000 [==============================] - 2s - loss: 7.3891e-06 - acc: 1.0000     
Epoch 10/20
2000/2000 [==============================] - 2s - loss: 9.6046e-06 - acc: 1.0000     
Epoch 11/20
2000/2000 [==============================] - 2s - loss: 2.1828e-07 - acc: 1.0000     
Epoch 12/20
2000/2000 [==============================] - 2s - loss: 1.2299e-07 - acc: 1.0000     
Epoch 13/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 14/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 15/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 16/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 17/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 18/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 19/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     
Epoch 20/20
2000/2000 [==============================] - 2s - loss: 1.1921e-07 - acc: 1.0000     

In [13]:
test_data


Out[13]:
array([[2, 1, 0, 4],
       [1, 1, 1, 2],
       [0, 0, 0, 2],
       [2, 0, 0, 9],
       [0, 0, 1, 4],
       [1, 2, 2, 2],
       [2, 0, 1, 6],
       [1, 0, 0, 8],
       [2, 1, 0, 9],
       [2, 1, 1, 1]])

In [14]:
# Softmax class probabilities for the first test sample.
arr = model.predict(test_data)[0]
arr


Out[14]:
array([1., 0., 0., 0.], dtype=float32)

In [15]:
np.argmax(arr)


Out[15]:
0

In [16]:
def max_ndarray(array):
    """Return (index, value) of the largest element of a 1-D sequence.

    Returns the sentinel (-1, -1.0) for an empty input (kept from the
    original for backward compatibility).

    Fix: the old code initialized max_v to a fixed -1.0, so any array
    whose values were all below -1 wrongly returned (-1, -1.0).  We now
    seed the maximum from the first element instead.
    """
    max_i = -1
    max_v = -1.0
    for i, val in enumerate(array):
        if max_i == -1 or val > max_v:
            max_v = val
            max_i = i
    return max_i, max_v

max_ndarray(arr)


Out[16]:
(0, 1.0)

In [17]:
# Unit test: index 1 holds the largest value (0.4).
arr = np.array([0.1, 0.4, 0.3, 0.2])
assert max_ndarray(arr) == (1, 0.4)

In [18]:
np.amax(model.predict(test_data)[0])


Out[18]:
1.0

In [19]:
#np.int(np.round(model.predict(test_data).dot(np.array([1, 2, 3, 4]))-1))

In [20]:
# Expected labels for the test set (column vector, shape (10, 1)).
y = test_targets
y


Out[20]:
array([[0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0]])

In [21]:
# Wrap one sample so predict() receives a batch of shape (1, 4).
td0 = np.stack([test_data[3]])
td0


Out[21]:
array([[2, 0, 0, 9]])

In [22]:
model.predict(td0)


Out[22]:
array([[1., 0., 0., 0.]], dtype=float32)

In [23]:
np.round(model.predict(td0).dot(np.array([1, 2, 3, 4]))-1)


Out[23]:
array([0.])

In [26]:
# Training loss per epoch from the last fit() call.
loss = history.history['loss']
epochs = range(len(loss))
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()  # without this, the label= above is never rendered
plt.show()



In [27]:
# Training accuracy per epoch from the last fit() call.
acc = history.history['acc']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training Accuracy')  # was mislabeled 'Training Loss' (copy-paste)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()



In [ ]:


In [35]:
# NOTE(review): duplicate of the earlier loss-plot cell.  The captured
# traceback below shows `history` was a plain list at this point in the
# kernel (hidden state from out-of-order execution), hence the
# AttributeError — re-running a fit cell first restores a History object.
loss = history.history['loss']
epochs = range(len(loss))
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.show()


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-35-da04cd6f72d4> in <module>()
----> 1 loss = history.history['loss']
      2 epochs = range(len(loss))
      3 plt.plot(epochs, loss, 'bo', label='Training Loss')
      4 plt.show()

AttributeError: 'list' object has no attribute 'history'

In [31]:
# Training accuracy per epoch (duplicate of the earlier accuracy plot).
acc = history.history['acc']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training Accuracy')  # was mislabeled 'Training Loss' (copy-paste)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()



In [ ]: