In [1]:
import numpy as np
from layer import *

X_xor = np.array([[x, y] for x in range(2) for y in range(2)]) T_xor = np.array([[x ^ y] for x, y in X_xor])

mdl_xor1 = Model([ FullyConnectedLayer(2, 3), Sigmoid(), FullyConnectedLayer(3, 1), Sigmoid() ], 0.01, 0)

for _ in range(100000): mdl_xor1.train(X_xor, T_xor)

mdl_xor1.test(X_xor, T_xor)


In [2]:
def cossim(o, t):
    return (o * t).sum() / (np.sqrt((o ** 2).sum()) * np.sqrt((t ** 2).sum()))

In [3]:
from main import *

In [14]:
mdl = Model([
    Flatten([-1, 32, 32, 3], [-1, 32 * 32 * 3]),
    FullyConnectedLayer(32 * 32 * 3, 96),
    Sigmoid(),
    FullyConnectedLayer(96, 10),
    SoftMax(is_output=True),
], 0.001, 0, 1000)

In [16]:
for i in range(1):
    mdl.train(X, T)
    print(i, np.array([cossim(x.round(3), t) for x, t in zip(mdl.layers[-1].val, T)]).mean())


/home/pluvian/code/eps/cnn-py/layer.py:45: RuntimeWarning: overflow encountered in exp
  self.val = 1 / (1 + np.exp(-x))
0 0.31503910008

In [17]:
np.array([cossim(x.round(3), t) for x, t in zip(mdl.layers[-1].val, T)])


Out[17]:
array([ 0.28278267,  0.32362906,  0.32362906,  0.30477688,  0.33305515,
        0.33305515,  0.32362906,  0.30477688,  0.39589574,  0.317345  ,
        0.30477688,  0.30477688,  0.30477688,  0.32362906,  0.32362906,
        0.32362906,  0.32362906,  0.317345  ,  0.32362906,  0.28278267,
        0.30477688,  0.317345  ,  0.28278267,  0.28278267,  0.32362906,
        0.28278267,  0.317345  ,  0.27649861,  0.30477688,  0.28278267,
        0.28278267,  0.32362906,  0.33305515,  0.317345  ,  0.30477688,
        0.28278267,  0.317345  ,  0.30477688,  0.317345  ,  0.317345  ,
        0.27649861,  0.32362906,  0.32362906,  0.30477688,  0.33305515,
        0.33305515,  0.33305515,  0.32362906,  0.32362906,  0.28278267,
        0.32362906,  0.27649861,  0.30477688,  0.32362906,  0.32362906,
        0.32362906,  0.27649861,  0.32362906,  0.30477688,  0.317345  ,
        0.33305515,  0.33305515,  0.39589574,  0.32362906,  0.33305515,
        0.33305515,  0.30477688,  0.32362906,  0.30477688,  0.39589574,
        0.27649861,  0.32362906,  0.28278267,  0.30477688,  0.317345  ,
        0.33305515,  0.32362906,  0.28278267,  0.317345  ,  0.33305515,
        0.317345  ,  0.27649861,  0.30477688,  0.27649861,  0.30477688,
        0.30477688,  0.30477688,  0.30477688,  0.32362906,  0.30477688,
        0.32362906,  0.317345  ,  0.39589574,  0.28278267,  0.33305515,
        0.28278267,  0.33305515,  0.33305515,  0.30477688,  0.33305515,
        0.39589574,  0.317345  ,  0.32362906,  0.28278267,  0.28278267,
        0.33305515,  0.39589574,  0.27649861,  0.32362906,  0.32362906,
        0.32362906,  0.39589574,  0.33305515,  0.30477688,  0.30477688,
        0.28278267,  0.28278267,  0.28278267,  0.32362906,  0.33305515,
        0.32362906,  0.32362906,  0.32362906,  0.32362906,  0.28278267,
        0.28278267,  0.33305515,  0.32362906,  0.27649861,  0.28278267,
        0.30477688,  0.30477688,  0.28278267,  0.30477688,  0.33305515,
        0.39589574,  0.33305515,  0.33305515,  0.32362906,  0.39589574,
        0.33305515,  0.317345  ,  0.317345  ,  0.28278267,  0.32362906,
        0.30477688,  0.32362906,  0.32362906,  0.27649861,  0.30477688,
        0.317345  ,  0.28278267,  0.30477688,  0.30477688,  0.28278267,
        0.39589574,  0.27649861,  0.27649861,  0.30477688,  0.317345  ,
        0.33305515,  0.39589574,  0.30477688,  0.30477688,  0.28278267,
        0.28278267,  0.32362906,  0.27649861,  0.33305515,  0.317345  ,
        0.39589574,  0.32362906,  0.30477688,  0.27649861,  0.317345  ,
        0.30477688,  0.33305515,  0.27649861,  0.30477688,  0.28278267,
        0.30477688,  0.30477688,  0.27649861,  0.27649861,  0.33305515,
        0.28278267,  0.32362906,  0.28278267,  0.32362906,  0.28278267,
        0.39589574,  0.30477688,  0.39589574,  0.39589574,  0.32362906,
        0.27649861,  0.32362906,  0.317345  ,  0.27649861,  0.28278267,
        0.28278267,  0.33305515,  0.32362906,  0.317345  ,  0.28278267,
        0.32362906,  0.33305515,  0.317345  ,  0.32362906,  0.28278267,
        0.28278267,  0.30477688,  0.33305515,  0.28278267,  0.32362906,
        0.27649861,  0.39589574,  0.27649861,  0.32362906,  0.32362906,
        0.28278267,  0.39589574,  0.39589574,  0.28278267,  0.28278267,
        0.32362906,  0.33305515,  0.33305515,  0.28278267,  0.317345  ,
        0.30477688,  0.28278267,  0.28278267,  0.28278267,  0.28278267,
        0.28278267,  0.33305515,  0.30477688,  0.33305515,  0.27649861,
        0.39589574,  0.317345  ,  0.28278267,  0.28278267,  0.39589574,
        0.28278267,  0.39589574,  0.30477688,  0.28278267,  0.28278267,
        0.33305515,  0.317345  ,  0.39589574,  0.317345  ,  0.30477688,
        0.33305515,  0.30477688,  0.33305515,  0.317345  ,  0.39589574,
        0.27649861,  0.33305515,  0.33305515,  0.30477688,  0.28278267,
        0.32362906,  0.317345  ,  0.30477688,  0.30477688,  0.32362906,
        0.32362906,  0.32362906,  0.30477688,  0.32362906,  0.32362906,
        0.33305515,  0.28278267,  0.27649861,  0.32362906,  0.28278267,
        0.39589574,  0.32362906,  0.33305515,  0.32362906,  0.28278267,
        0.27649861,  0.28278267,  0.317345  ,  0.32362906,  0.30477688,
        0.39589574,  0.39589574,  0.28278267,  0.28278267,  0.30477688,
        0.32362906,  0.30477688,  0.27649861,  0.28278267,  0.30477688,
        0.32362906,  0.33305515,  0.33305515,  0.32362906,  0.33305515,
        0.27649861,  0.32362906,  0.32362906,  0.28278267,  0.39589574,
        0.30477688,  0.33305515,  0.33305515,  0.28278267,  0.317345  ,
        0.317345  ,  0.32362906,  0.28278267,  0.30477688,  0.32362906,
        0.30477688,  0.30477688,  0.32362906,  0.33305515,  0.27649861,
        0.33305515,  0.28278267,  0.28278267,  0.39589574,  0.30477688,
        0.33305515,  0.317345  ,  0.28278267,  0.317345  ,  0.317345  ,
        0.32362906,  0.30477688,  0.27649861,  0.30477688,  0.27649861,
        0.32362906,  0.28278267,  0.317345  ,  0.30477688,  0.28278267,
        0.30477688,  0.30477688,  0.28278267,  0.28278267,  0.28278267,
        0.28278267,  0.28278267,  0.28278267,  0.39589574,  0.33305515,
        0.28278267,  0.32362906,  0.32362906,  0.32362906,  0.27649861,
        0.32362906,  0.28278267,  0.30477688,  0.30477688,  0.33305515,
        0.39589574,  0.30477688,  0.317345  ,  0.28278267,  0.32362906,
        0.317345  ,  0.28278267,  0.30477688,  0.28278267,  0.27649861,
        0.33305515,  0.28278267,  0.317345  ,  0.30477688,  0.39589574,
        0.27649861,  0.30477688,  0.30477688,  0.32362906,  0.317345  ,
        0.32362906,  0.30477688,  0.28278267,  0.30477688,  0.33305515,
        0.30477688,  0.30477688,  0.28278267,  0.33305515,  0.30477688,
        0.317345  ,  0.33305515,  0.39589574,  0.30477688,  0.30477688,
        0.32362906,  0.28278267,  0.32362906,  0.32362906,  0.28278267,
        0.28278267,  0.32362906,  0.28278267,  0.32362906,  0.28278267,
        0.39589574,  0.32362906,  0.30477688,  0.30477688,  0.30477688,
        0.28278267,  0.317345  ,  0.39310593,  0.39589574,  0.32362906,
        0.30477688,  0.32362906,  0.30477688,  0.32362906,  0.27649861,
        0.32362906,  0.27649861,  0.33305515,  0.32362906,  0.30477688,
        0.39589574,  0.27649861,  0.33305515,  0.30477688,  0.30477688,
        0.30477688,  0.28278267,  0.28278267,  0.32362906,  0.28278267,
        0.30477688,  0.39589574,  0.39589574,  0.32362906,  0.32362906,
        0.317345  ,  0.317345  ,  0.30477688,  0.28278267,  0.30477688,
        0.27649861,  0.28278267,  0.28278267,  0.28278267,  0.33305515,
        0.28278267,  0.39589574,  0.28278267,  0.30477688,  0.39589574,
        0.39589574,  0.33305515,  0.27649861,  0.32362906,  0.28278267,
        0.39589574,  0.33305515,  0.28278267,  0.28278267,  0.30477688,
        0.30477688,  0.27649861,  0.32362906,  0.28278267,  0.32362906,
        0.39589574,  0.317345  ,  0.30477688,  0.30477688,  0.317345  ,
        0.32362906,  0.28278267,  0.33305515,  0.32362906,  0.30477688,
        0.39589574,  0.33305515,  0.39589574,  0.28278267,  0.30477688,
        0.30477688,  0.27649861,  0.30477688,  0.33305515,  0.317345  ,
        0.32362906,  0.39589574,  0.28278267,  0.33305515,  0.30477688,
        0.27649861,  0.39589574,  0.32362906,  0.39589574,  0.28278267,
        0.30477688,  0.33305515,  0.39589574,  0.32362906,  0.39589574,
        0.32362906,  0.32362906,  0.32362906,  0.32362906,  0.30477688,
        0.27649861,  0.30477688,  0.317345  ,  0.39589574,  0.39589574,
        0.30477688,  0.30477688,  0.32362906,  0.30477688,  0.33305515,
        0.28278267,  0.30477688,  0.28278267,  0.30477688,  0.28278267,
        0.32362906,  0.30477688,  0.28278267,  0.32362906,  0.27649861,
        0.27649861,  0.33305515,  0.30477688,  0.32362906,  0.32362906,
        0.32362906,  0.32362906,  0.27649861,  0.30477688,  0.32362906,
        0.30477688,  0.39589574,  0.33305515,  0.317345  ,  0.30477688,
        0.317345  ,  0.30477688,  0.28278267,  0.32362906,  0.39589574,
        0.28278267,  0.28278267,  0.28278267,  0.32362906,  0.32362906,
        0.32362906,  0.33305515,  0.39589574,  0.30477688,  0.28278267,
        0.33305515,  0.39589574,  0.39589574,  0.33305515,  0.27649861,
        0.30477688,  0.28278267,  0.30477688,  0.27649861,  0.39589574,
        0.30477688,  0.33305515,  0.32362906,  0.33305515,  0.32362906,
        0.39589574,  0.30477688,  0.30477688,  0.317345  ,  0.39589574,
        0.39589574,  0.32362906,  0.28278267,  0.28278267,  0.30477688,
        0.33305515,  0.28278267,  0.39589574,  0.33305515,  0.32362906,
        0.30477688,  0.39589574,  0.317345  ,  0.28278267,  0.33305515,
        0.28278267,  0.39589574,  0.39589574,  0.317345  ,  0.28278267,
        0.28278267,  0.33305515,  0.27649861,  0.28278267,  0.39589574,
        0.39589574,  0.30477688,  0.32362906,  0.32362906,  0.28278267,
        0.32362906,  0.30477688,  0.33305515,  0.317345  ,  0.28278267,
        0.28278267,  0.30477688,  0.30477688,  0.30477688,  0.27649861,
        0.28278267,  0.28278267,  0.39589574,  0.28278267,  0.317345  ,
        0.32362906,  0.39589574,  0.30477688,  0.28278267,  0.32362906,
        0.32362906,  0.30477688,  0.28278267,  0.317345  ,  0.317345  ,
        0.28278267,  0.30477688,  0.30477688,  0.32362906,  0.33305515,
        0.28278267,  0.32362906,  0.30477688,  0.32362906,  0.32362906,
        0.28278267,  0.28278267,  0.30477688,  0.27649861,  0.30477688,
        0.28278267,  0.39589574,  0.32362906,  0.28278267,  0.32362906,
        0.30477688,  0.30477688,  0.30477688,  0.28278267,  0.32362906,
        0.30477688,  0.32362906,  0.28278267,  0.32362906,  0.30477688,
        0.27649861,  0.30477688,  0.32362906,  0.32362906,  0.30477688,
        0.27649861,  0.33305515,  0.30477688,  0.317345  ,  0.32362906,
        0.28278267,  0.27649861,  0.28278267,  0.32362906,  0.317345  ,
        0.317345  ,  0.27649861,  0.28278267,  0.30477688,  0.32362906,
        0.33305515,  0.317345  ,  0.28278267,  0.30477688,  0.28278267,
        0.28278267,  0.32362906,  0.27649861,  0.28278267,  0.33305515,
        0.28278267,  0.32362906,  0.317345  ,  0.32362906,  0.39589574,
        0.30477688,  0.32362906,  0.39589574,  0.28278267,  0.32362906,
        0.28278267,  0.30477688,  0.30477688,  0.28278267,  0.33305515,
        0.39589574,  0.39589574,  0.317345  ,  0.28278267,  0.32362906,
        0.28278267,  0.28278267,  0.30477688,  0.39589574,  0.32362906,
        0.30477688,  0.27649861,  0.30477688,  0.28278267,  0.27649861,
        0.317345  ,  0.28278267,  0.27649861,  0.28278267,  0.27649861,
        0.28278267,  0.39589574,  0.32362906,  0.28278267,  0.30477688,
        0.317345  ,  0.39589574,  0.32362906,  0.33305515,  0.30477688,
        0.28278267,  0.30477688,  0.33305515,  0.28278267,  0.32362906,
        0.27649861,  0.27649861,  0.28278267,  0.33305515,  0.30477688,
        0.28278267,  0.32362906,  0.28278267,  0.30477688,  0.30477688,
        0.30477688,  0.33305515,  0.27649861,  0.32362906,  0.30477688,
        0.28278267,  0.39589574,  0.27649861,  0.32362906,  0.32362906,
        0.28278267,  0.30477688,  0.33305515,  0.39589574,  0.317345  ,
        0.32362906,  0.317345  ,  0.39589574,  0.32362906,  0.32362906,
        0.30477688,  0.28278267,  0.28278267,  0.28278267,  0.27649861,
        0.317345  ,  0.39589574,  0.32362906,  0.317345  ,  0.30477688,
        0.32362906,  0.32362906,  0.317345  ,  0.39589574,  0.30477688,
        0.39589574,  0.32362906,  0.30477688,  0.32362906,  0.28278267,
        0.32362906,  0.317345  ,  0.32362906,  0.32362906,  0.32362906,
        0.317345  ,  0.317345  ,  0.28278267,  0.32362906,  0.317345  ,
        0.32362906,  0.39589574,  0.28278267,  0.27649861,  0.27649861,
        0.33305515,  0.30477688,  0.27649861,  0.28278267,  0.28278267,
        0.32362906,  0.30477688,  0.28278267,  0.33305515,  0.30477688,
        0.30477688,  0.39589574,  0.32362906,  0.32362906,  0.32362906,
        0.32362906,  0.30477688,  0.32362906,  0.33305515,  0.33305515,
        0.33305515,  0.28278267,  0.28278267,  0.28278267,  0.27649861,
        0.33305515,  0.33305515,  0.30477688,  0.28278267,  0.30477688,
        0.317345  ,  0.317345  ,  0.30477688,  0.33305515,  0.32362906,
        0.317345  ,  0.27649861,  0.27649861,  0.27649861,  0.28278267,
        0.33305515,  0.30477688,  0.317345  ,  0.30477688,  0.39589574,
        0.39589574,  0.317345  ,  0.28278267,  0.28278267,  0.32362906,
        0.317345  ,  0.28278267,  0.32362906,  0.30477688,  0.317345  ,
        0.39589574,  0.28278267,  0.28278267,  0.33305515,  0.33305515,
        0.27649861,  0.30477688,  0.32362906,  0.317345  ,  0.33305515,
        0.39589574,  0.32362906,  0.317345  ,  0.32362906,  0.32362906,
        0.32362906,  0.32362906,  0.30477688,  0.39589574,  0.32362906,
        0.32362906,  0.39589574,  0.39589574,  0.33305515,  0.27649861,
        0.317345  ,  0.28278267,  0.39589574,  0.30477688,  0.28278267,
        0.32362906,  0.39589574,  0.28278267,  0.28278267,  0.30477688,
        0.28278267,  0.28278267,  0.32362906,  0.27649861,  0.39589574,
        0.32362906,  0.28278267,  0.32362906,  0.30477688,  0.28278267,
        0.32362906,  0.30477688,  0.33305515,  0.27649861,  0.27649861,
        0.28278267,  0.28278267,  0.317345  ,  0.28278267,  0.32362906,
        0.30477688,  0.30477688,  0.28278267,  0.27649861,  0.28278267,
        0.30477688,  0.28278267,  0.27649861,  0.32362906,  0.30477688,
        0.28278267,  0.33305515,  0.28278267,  0.28278267,  0.30477688,
        0.28278267,  0.317345  ,  0.33305515,  0.39589574,  0.27649861,
        0.30477688,  0.30477688,  0.33305515,  0.30477688,  0.317345  ,
        0.32362906,  0.30477688,  0.30477688,  0.32362906,  0.30477688,
        0.317345  ,  0.30477688,  0.32362906,  0.39589574,  0.30477688,
        0.28278267,  0.28278267,  0.33305515,  0.32362906,  0.32362906,
        0.28278267,  0.30477688,  0.39589574,  0.30477688,  0.317345  ,
        0.32362906,  0.39589574,  0.30477688,  0.30477688,  0.28278267,
        0.32362906,  0.30477688,  0.33305515,  0.33305515,  0.30477688,
        0.33305515,  0.27649861,  0.30477688,  0.28278267,  0.27649861,
        0.28278267,  0.32362906,  0.39589574,  0.27649861,  0.28278267,
        0.32362906,  0.33305515,  0.317345  ,  0.27649861,  0.30477688,
        0.317345  ,  0.27649861,  0.33305515,  0.317345  ,  0.27649861])