In [1]:
import numpy as np
from layer import *
# XOR truth table: the four (x, y) bit pairs and, for each pair, the single
# target value x ^ y.
X_xor = np.array([[a, b] for a in range(2) for b in range(2)])
T_xor = np.array([[a ^ b] for a, b in X_xor])

# Small network: FullyConnectedLayer(2, 3) -> Sigmoid -> FullyConnectedLayer(3, 1) -> Sigmoid.
# NOTE(review): the trailing positional arguments (0.01, 0) are interpreted by
# layer.Model, which is not visible in this notebook -- confirm what they configure.
mdl_xor1 = Model(
    [
        FullyConnectedLayer(2, 3),
        Sigmoid(),
        FullyConnectedLayer(3, 1),
        Sigmoid(),
    ],
    0.01,
    0,
)

# 100000 training calls, each on the whole four-row table, followed by an
# evaluation call on the same rows (the cell displays whatever test() returns).
for _ in range(100000):
    mdl_xor1.train(X_xor, T_xor)
mdl_xor1.test(X_xor, T_xor)
In [2]:
def cossim(o, t):
    """Return the cosine similarity between ``o`` and ``t``.

    Both arguments are converted to floating-point arrays first, so plain
    Python sequences are accepted and small integer dtypes cannot overflow
    when squared.

    Parameters
    ----------
    o, t : array_like
        Numeric inputs whose shapes are equal or broadcastable; the
        element-wise product is summed over every element.

    Returns
    -------
    float
        ``sum(o * t) / (||o|| * ||t||)``. When either input has zero norm
        (including empty input) the similarity is undefined; ``0.0`` is
        returned instead of dividing by zero and producing ``nan``.
    """
    o = np.asarray(o, dtype=float)
    t = np.asarray(t, dtype=float)

    denom = np.sqrt((o ** 2).sum()) * np.sqrt((t ** 2).sum())
    # A zero denominator would yield nan (plus a RuntimeWarning), which then
    # contaminates any mean taken over per-sample similarities.
    if denom == 0:
        return 0.0
    return (o * t).sum() / denom
In [3]:
from main import *
In [14]:
# Layer stack, in order: Flatten -> FullyConnectedLayer -> Sigmoid ->
# FullyConnectedLayer -> SoftMax (flagged as the output layer).
# NOTE(review): the two shape lists given to Flatten look like
# (batch, 32, 32, 3) inputs reshaped to (batch, 3072) rows -- confirm against
# the Flatten implementation.
mdl = Model(
    [
        Flatten([-1, 32, 32, 3], [-1, 32 * 32 * 3]),
        FullyConnectedLayer(32 * 32 * 3, 96),
        Sigmoid(),
        FullyConnectedLayer(96, 10),
        SoftMax(is_output=True),
    ],
    # NOTE(review): the meaning of these three positional arguments is defined
    # by Model, which is not visible in this notebook -- confirm.
    0.001,
    0,
    1000,
)
In [16]:
for i in range(1):
    mdl.train(X, T)
    # Report the pass index and the mean cosine similarity between each row of
    # the last layer's stored `val` (rounded to 3 decimals) and its target row.
    # NOTE(review): zip() stops at the shorter sequence and always pairs with
    # the leading rows of T; if `val` only holds one mini-batch, these pairs may
    # not correspond to the same samples -- confirm against Model.train.
    print(
        i,
        np.array(
            [cossim(out.round(3), target) for out, target in zip(mdl.layers[-1].val, T)]
        ).mean(),
    )
/home/pluvian/code/eps/cnn-py/layer.py:45: RuntimeWarning: overflow encountered in exp
self.val = 1 / (1 + np.exp(-x))
0 0.31503910008
In [17]:
# Per-sample cosine similarity between the last layer's stored `val` rows
# (rounded to 3 decimals) and the target rows; displayed as the cell output.
np.array(
    [cossim(out.round(3), target) for out, target in zip(mdl.layers[-1].val, T)]
)
Out[17]:
array([ 0.28278267, 0.32362906, 0.32362906, 0.30477688, 0.33305515,
0.33305515, 0.32362906, 0.30477688, 0.39589574, 0.317345 ,
0.30477688, 0.30477688, 0.30477688, 0.32362906, 0.32362906,
0.32362906, 0.32362906, 0.317345 , 0.32362906, 0.28278267,
0.30477688, 0.317345 , 0.28278267, 0.28278267, 0.32362906,
0.28278267, 0.317345 , 0.27649861, 0.30477688, 0.28278267,
0.28278267, 0.32362906, 0.33305515, 0.317345 , 0.30477688,
0.28278267, 0.317345 , 0.30477688, 0.317345 , 0.317345 ,
0.27649861, 0.32362906, 0.32362906, 0.30477688, 0.33305515,
0.33305515, 0.33305515, 0.32362906, 0.32362906, 0.28278267,
0.32362906, 0.27649861, 0.30477688, 0.32362906, 0.32362906,
0.32362906, 0.27649861, 0.32362906, 0.30477688, 0.317345 ,
0.33305515, 0.33305515, 0.39589574, 0.32362906, 0.33305515,
0.33305515, 0.30477688, 0.32362906, 0.30477688, 0.39589574,
0.27649861, 0.32362906, 0.28278267, 0.30477688, 0.317345 ,
0.33305515, 0.32362906, 0.28278267, 0.317345 , 0.33305515,
0.317345 , 0.27649861, 0.30477688, 0.27649861, 0.30477688,
0.30477688, 0.30477688, 0.30477688, 0.32362906, 0.30477688,
0.32362906, 0.317345 , 0.39589574, 0.28278267, 0.33305515,
0.28278267, 0.33305515, 0.33305515, 0.30477688, 0.33305515,
0.39589574, 0.317345 , 0.32362906, 0.28278267, 0.28278267,
0.33305515, 0.39589574, 0.27649861, 0.32362906, 0.32362906,
0.32362906, 0.39589574, 0.33305515, 0.30477688, 0.30477688,
0.28278267, 0.28278267, 0.28278267, 0.32362906, 0.33305515,
0.32362906, 0.32362906, 0.32362906, 0.32362906, 0.28278267,
0.28278267, 0.33305515, 0.32362906, 0.27649861, 0.28278267,
0.30477688, 0.30477688, 0.28278267, 0.30477688, 0.33305515,
0.39589574, 0.33305515, 0.33305515, 0.32362906, 0.39589574,
0.33305515, 0.317345 , 0.317345 , 0.28278267, 0.32362906,
0.30477688, 0.32362906, 0.32362906, 0.27649861, 0.30477688,
0.317345 , 0.28278267, 0.30477688, 0.30477688, 0.28278267,
0.39589574, 0.27649861, 0.27649861, 0.30477688, 0.317345 ,
0.33305515, 0.39589574, 0.30477688, 0.30477688, 0.28278267,
0.28278267, 0.32362906, 0.27649861, 0.33305515, 0.317345 ,
0.39589574, 0.32362906, 0.30477688, 0.27649861, 0.317345 ,
0.30477688, 0.33305515, 0.27649861, 0.30477688, 0.28278267,
0.30477688, 0.30477688, 0.27649861, 0.27649861, 0.33305515,
0.28278267, 0.32362906, 0.28278267, 0.32362906, 0.28278267,
0.39589574, 0.30477688, 0.39589574, 0.39589574, 0.32362906,
0.27649861, 0.32362906, 0.317345 , 0.27649861, 0.28278267,
0.28278267, 0.33305515, 0.32362906, 0.317345 , 0.28278267,
0.32362906, 0.33305515, 0.317345 , 0.32362906, 0.28278267,
0.28278267, 0.30477688, 0.33305515, 0.28278267, 0.32362906,
0.27649861, 0.39589574, 0.27649861, 0.32362906, 0.32362906,
0.28278267, 0.39589574, 0.39589574, 0.28278267, 0.28278267,
0.32362906, 0.33305515, 0.33305515, 0.28278267, 0.317345 ,
0.30477688, 0.28278267, 0.28278267, 0.28278267, 0.28278267,
0.28278267, 0.33305515, 0.30477688, 0.33305515, 0.27649861,
0.39589574, 0.317345 , 0.28278267, 0.28278267, 0.39589574,
0.28278267, 0.39589574, 0.30477688, 0.28278267, 0.28278267,
0.33305515, 0.317345 , 0.39589574, 0.317345 , 0.30477688,
0.33305515, 0.30477688, 0.33305515, 0.317345 , 0.39589574,
0.27649861, 0.33305515, 0.33305515, 0.30477688, 0.28278267,
0.32362906, 0.317345 , 0.30477688, 0.30477688, 0.32362906,
0.32362906, 0.32362906, 0.30477688, 0.32362906, 0.32362906,
0.33305515, 0.28278267, 0.27649861, 0.32362906, 0.28278267,
0.39589574, 0.32362906, 0.33305515, 0.32362906, 0.28278267,
0.27649861, 0.28278267, 0.317345 , 0.32362906, 0.30477688,
0.39589574, 0.39589574, 0.28278267, 0.28278267, 0.30477688,
0.32362906, 0.30477688, 0.27649861, 0.28278267, 0.30477688,
0.32362906, 0.33305515, 0.33305515, 0.32362906, 0.33305515,
0.27649861, 0.32362906, 0.32362906, 0.28278267, 0.39589574,
0.30477688, 0.33305515, 0.33305515, 0.28278267, 0.317345 ,
0.317345 , 0.32362906, 0.28278267, 0.30477688, 0.32362906,
0.30477688, 0.30477688, 0.32362906, 0.33305515, 0.27649861,
0.33305515, 0.28278267, 0.28278267, 0.39589574, 0.30477688,
0.33305515, 0.317345 , 0.28278267, 0.317345 , 0.317345 ,
0.32362906, 0.30477688, 0.27649861, 0.30477688, 0.27649861,
0.32362906, 0.28278267, 0.317345 , 0.30477688, 0.28278267,
0.30477688, 0.30477688, 0.28278267, 0.28278267, 0.28278267,
0.28278267, 0.28278267, 0.28278267, 0.39589574, 0.33305515,
0.28278267, 0.32362906, 0.32362906, 0.32362906, 0.27649861,
0.32362906, 0.28278267, 0.30477688, 0.30477688, 0.33305515,
0.39589574, 0.30477688, 0.317345 , 0.28278267, 0.32362906,
0.317345 , 0.28278267, 0.30477688, 0.28278267, 0.27649861,
0.33305515, 0.28278267, 0.317345 , 0.30477688, 0.39589574,
0.27649861, 0.30477688, 0.30477688, 0.32362906, 0.317345 ,
0.32362906, 0.30477688, 0.28278267, 0.30477688, 0.33305515,
0.30477688, 0.30477688, 0.28278267, 0.33305515, 0.30477688,
0.317345 , 0.33305515, 0.39589574, 0.30477688, 0.30477688,
0.32362906, 0.28278267, 0.32362906, 0.32362906, 0.28278267,
0.28278267, 0.32362906, 0.28278267, 0.32362906, 0.28278267,
0.39589574, 0.32362906, 0.30477688, 0.30477688, 0.30477688,
0.28278267, 0.317345 , 0.39310593, 0.39589574, 0.32362906,
0.30477688, 0.32362906, 0.30477688, 0.32362906, 0.27649861,
0.32362906, 0.27649861, 0.33305515, 0.32362906, 0.30477688,
0.39589574, 0.27649861, 0.33305515, 0.30477688, 0.30477688,
0.30477688, 0.28278267, 0.28278267, 0.32362906, 0.28278267,
0.30477688, 0.39589574, 0.39589574, 0.32362906, 0.32362906,
0.317345 , 0.317345 , 0.30477688, 0.28278267, 0.30477688,
0.27649861, 0.28278267, 0.28278267, 0.28278267, 0.33305515,
0.28278267, 0.39589574, 0.28278267, 0.30477688, 0.39589574,
0.39589574, 0.33305515, 0.27649861, 0.32362906, 0.28278267,
0.39589574, 0.33305515, 0.28278267, 0.28278267, 0.30477688,
0.30477688, 0.27649861, 0.32362906, 0.28278267, 0.32362906,
0.39589574, 0.317345 , 0.30477688, 0.30477688, 0.317345 ,
0.32362906, 0.28278267, 0.33305515, 0.32362906, 0.30477688,
0.39589574, 0.33305515, 0.39589574, 0.28278267, 0.30477688,
0.30477688, 0.27649861, 0.30477688, 0.33305515, 0.317345 ,
0.32362906, 0.39589574, 0.28278267, 0.33305515, 0.30477688,
0.27649861, 0.39589574, 0.32362906, 0.39589574, 0.28278267,
0.30477688, 0.33305515, 0.39589574, 0.32362906, 0.39589574,
0.32362906, 0.32362906, 0.32362906, 0.32362906, 0.30477688,
0.27649861, 0.30477688, 0.317345 , 0.39589574, 0.39589574,
0.30477688, 0.30477688, 0.32362906, 0.30477688, 0.33305515,
0.28278267, 0.30477688, 0.28278267, 0.30477688, 0.28278267,
0.32362906, 0.30477688, 0.28278267, 0.32362906, 0.27649861,
0.27649861, 0.33305515, 0.30477688, 0.32362906, 0.32362906,
0.32362906, 0.32362906, 0.27649861, 0.30477688, 0.32362906,
0.30477688, 0.39589574, 0.33305515, 0.317345 , 0.30477688,
0.317345 , 0.30477688, 0.28278267, 0.32362906, 0.39589574,
0.28278267, 0.28278267, 0.28278267, 0.32362906, 0.32362906,
0.32362906, 0.33305515, 0.39589574, 0.30477688, 0.28278267,
0.33305515, 0.39589574, 0.39589574, 0.33305515, 0.27649861,
0.30477688, 0.28278267, 0.30477688, 0.27649861, 0.39589574,
0.30477688, 0.33305515, 0.32362906, 0.33305515, 0.32362906,
0.39589574, 0.30477688, 0.30477688, 0.317345 , 0.39589574,
0.39589574, 0.32362906, 0.28278267, 0.28278267, 0.30477688,
0.33305515, 0.28278267, 0.39589574, 0.33305515, 0.32362906,
0.30477688, 0.39589574, 0.317345 , 0.28278267, 0.33305515,
0.28278267, 0.39589574, 0.39589574, 0.317345 , 0.28278267,
0.28278267, 0.33305515, 0.27649861, 0.28278267, 0.39589574,
0.39589574, 0.30477688, 0.32362906, 0.32362906, 0.28278267,
0.32362906, 0.30477688, 0.33305515, 0.317345 , 0.28278267,
0.28278267, 0.30477688, 0.30477688, 0.30477688, 0.27649861,
0.28278267, 0.28278267, 0.39589574, 0.28278267, 0.317345 ,
0.32362906, 0.39589574, 0.30477688, 0.28278267, 0.32362906,
0.32362906, 0.30477688, 0.28278267, 0.317345 , 0.317345 ,
0.28278267, 0.30477688, 0.30477688, 0.32362906, 0.33305515,
0.28278267, 0.32362906, 0.30477688, 0.32362906, 0.32362906,
0.28278267, 0.28278267, 0.30477688, 0.27649861, 0.30477688,
0.28278267, 0.39589574, 0.32362906, 0.28278267, 0.32362906,
0.30477688, 0.30477688, 0.30477688, 0.28278267, 0.32362906,
0.30477688, 0.32362906, 0.28278267, 0.32362906, 0.30477688,
0.27649861, 0.30477688, 0.32362906, 0.32362906, 0.30477688,
0.27649861, 0.33305515, 0.30477688, 0.317345 , 0.32362906,
0.28278267, 0.27649861, 0.28278267, 0.32362906, 0.317345 ,
0.317345 , 0.27649861, 0.28278267, 0.30477688, 0.32362906,
0.33305515, 0.317345 , 0.28278267, 0.30477688, 0.28278267,
0.28278267, 0.32362906, 0.27649861, 0.28278267, 0.33305515,
0.28278267, 0.32362906, 0.317345 , 0.32362906, 0.39589574,
0.30477688, 0.32362906, 0.39589574, 0.28278267, 0.32362906,
0.28278267, 0.30477688, 0.30477688, 0.28278267, 0.33305515,
0.39589574, 0.39589574, 0.317345 , 0.28278267, 0.32362906,
0.28278267, 0.28278267, 0.30477688, 0.39589574, 0.32362906,
0.30477688, 0.27649861, 0.30477688, 0.28278267, 0.27649861,
0.317345 , 0.28278267, 0.27649861, 0.28278267, 0.27649861,
0.28278267, 0.39589574, 0.32362906, 0.28278267, 0.30477688,
0.317345 , 0.39589574, 0.32362906, 0.33305515, 0.30477688,
0.28278267, 0.30477688, 0.33305515, 0.28278267, 0.32362906,
0.27649861, 0.27649861, 0.28278267, 0.33305515, 0.30477688,
0.28278267, 0.32362906, 0.28278267, 0.30477688, 0.30477688,
0.30477688, 0.33305515, 0.27649861, 0.32362906, 0.30477688,
0.28278267, 0.39589574, 0.27649861, 0.32362906, 0.32362906,
0.28278267, 0.30477688, 0.33305515, 0.39589574, 0.317345 ,
0.32362906, 0.317345 , 0.39589574, 0.32362906, 0.32362906,
0.30477688, 0.28278267, 0.28278267, 0.28278267, 0.27649861,
0.317345 , 0.39589574, 0.32362906, 0.317345 , 0.30477688,
0.32362906, 0.32362906, 0.317345 , 0.39589574, 0.30477688,
0.39589574, 0.32362906, 0.30477688, 0.32362906, 0.28278267,
0.32362906, 0.317345 , 0.32362906, 0.32362906, 0.32362906,
0.317345 , 0.317345 , 0.28278267, 0.32362906, 0.317345 ,
0.32362906, 0.39589574, 0.28278267, 0.27649861, 0.27649861,
0.33305515, 0.30477688, 0.27649861, 0.28278267, 0.28278267,
0.32362906, 0.30477688, 0.28278267, 0.33305515, 0.30477688,
0.30477688, 0.39589574, 0.32362906, 0.32362906, 0.32362906,
0.32362906, 0.30477688, 0.32362906, 0.33305515, 0.33305515,
0.33305515, 0.28278267, 0.28278267, 0.28278267, 0.27649861,
0.33305515, 0.33305515, 0.30477688, 0.28278267, 0.30477688,
0.317345 , 0.317345 , 0.30477688, 0.33305515, 0.32362906,
0.317345 , 0.27649861, 0.27649861, 0.27649861, 0.28278267,
0.33305515, 0.30477688, 0.317345 , 0.30477688, 0.39589574,
0.39589574, 0.317345 , 0.28278267, 0.28278267, 0.32362906,
0.317345 , 0.28278267, 0.32362906, 0.30477688, 0.317345 ,
0.39589574, 0.28278267, 0.28278267, 0.33305515, 0.33305515,
0.27649861, 0.30477688, 0.32362906, 0.317345 , 0.33305515,
0.39589574, 0.32362906, 0.317345 , 0.32362906, 0.32362906,
0.32362906, 0.32362906, 0.30477688, 0.39589574, 0.32362906,
0.32362906, 0.39589574, 0.39589574, 0.33305515, 0.27649861,
0.317345 , 0.28278267, 0.39589574, 0.30477688, 0.28278267,
0.32362906, 0.39589574, 0.28278267, 0.28278267, 0.30477688,
0.28278267, 0.28278267, 0.32362906, 0.27649861, 0.39589574,
0.32362906, 0.28278267, 0.32362906, 0.30477688, 0.28278267,
0.32362906, 0.30477688, 0.33305515, 0.27649861, 0.27649861,
0.28278267, 0.28278267, 0.317345 , 0.28278267, 0.32362906,
0.30477688, 0.30477688, 0.28278267, 0.27649861, 0.28278267,
0.30477688, 0.28278267, 0.27649861, 0.32362906, 0.30477688,
0.28278267, 0.33305515, 0.28278267, 0.28278267, 0.30477688,
0.28278267, 0.317345 , 0.33305515, 0.39589574, 0.27649861,
0.30477688, 0.30477688, 0.33305515, 0.30477688, 0.317345 ,
0.32362906, 0.30477688, 0.30477688, 0.32362906, 0.30477688,
0.317345 , 0.30477688, 0.32362906, 0.39589574, 0.30477688,
0.28278267, 0.28278267, 0.33305515, 0.32362906, 0.32362906,
0.28278267, 0.30477688, 0.39589574, 0.30477688, 0.317345 ,
0.32362906, 0.39589574, 0.30477688, 0.30477688, 0.28278267,
0.32362906, 0.30477688, 0.33305515, 0.33305515, 0.30477688,
0.33305515, 0.27649861, 0.30477688, 0.28278267, 0.27649861,
0.28278267, 0.32362906, 0.39589574, 0.27649861, 0.28278267,
0.32362906, 0.33305515, 0.317345 , 0.27649861, 0.30477688,
0.317345 , 0.27649861, 0.33305515, 0.317345 , 0.27649861])
Content source: PluVian/eps2017
Similar notebooks: