In [ ]:
from numpy import append, arange, array, concatenate, dot, meshgrid, ones, shape, where
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [ ]:
def fit(X, y, eta=.25, n_iter=1000, initial_weights=None):
    """Train a single-layer perceptron with the online perceptron learning rule.

    A bias column of -1 is prepended to ``X`` so that ``weights[0]`` acts as
    the activation threshold. Samples are presented in order; the output is
    1 when ``dot(x, weights) > 0`` and 0 otherwise, and after every sample
    the weights move by ``eta * (target - output) * x``.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
        Training inputs.
    y : array_like, shape (n_samples,) or (n_samples, 1)
        Targets, expected to be 0 or 1 (column vectors are flattened).
    eta : float, optional
        Learning rate. Default 0.25.
    n_iter : int, optional
        Maximum number of passes over the data. Default 1000. Training stops
        early after a pass that changes nothing, because every later pass
        would then be identical and could not change the weights either.
    initial_weights : array_like, shape (n_features + 1,), optional
        Starting weights, bias weight first. It is copied, never modified.
        Defaults to ``[-0.05, -0.02, 0.02]`` for two features and to zeros
        for any other number of features.

    Returns
    -------
    numpy.ndarray, shape (n_features + 1,)
        Learned weights, bias weight first.

    Raises
    ------
    ValueError
        If the number of targets differs from the number of samples, or if
        ``initial_weights`` has the wrong length.
    """
    X = np.asarray(X, dtype=float)
    # Bias input of -1 on every sample: weights[0] then plays the threshold role.
    inputs = np.concatenate((-np.ones((X.shape[0], 1)), X), axis=1)
    n_data, n_weights = inputs.shape

    # Accept both column-vector targets (as used in this notebook) and flat ones.
    targets = np.ravel(y)
    if targets.shape[0] != n_data:
        raise ValueError(
            "y has {} targets but X has {} samples".format(targets.shape[0], n_data))

    if initial_weights is None:
        if n_weights == 3:
            weights = np.array([-0.05, -0.02, 0.02])
        else:
            weights = np.zeros(n_weights)
    else:
        # Copy so the in-place updates below never touch the caller's array.
        weights = np.array(initial_weights, dtype=float)
        if weights.shape != (n_weights,):
            raise ValueError(
                "initial_weights must have shape ({},), got {}".format(
                    n_weights, weights.shape))

    for _ in range(n_iter):
        updated = False
        for x, target in zip(inputs, targets):
            output = 1 if np.dot(x, weights) > 0 else 0
            diff = target - output
            if diff != 0:
                # Perceptron rule applied to every weight at once (bias included).
                weights += eta * diff * x
                updated = True
        if not updated:
            break
    return weights
In [ ]:
def predict(X, weights):
    """Classify samples with trained perceptron weights.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
        Samples to classify.
    weights : array_like, shape (n_features + 1,)
        Weights as returned by ``fit``, bias weight first.

    Returns
    -------
    numpy.ndarray of float, shape (n_samples,)
        1.0 where ``dot([-1, *x], weights) > 0``, else 0.0.
    """
    X = np.asarray(X, dtype=float)
    # Same -1 bias column that fit() prepends during training.
    inputs = np.concatenate((-np.ones((X.shape[0], 1)), X), axis=1)
    # One vectorized threshold instead of growing an array with append, which
    # copies the whole result on every step (quadratic on large meshes).
    return np.where(np.dot(inputs, np.asarray(weights)) > 0, 1, 0).astype(float)
In [ ]:
def plot(X, y, weights, h=.02, padding=.1):
    """Draw the perceptron's decision regions with the training points on top.

    Every point of a regular grid covering the data is classified with
    ``predict``, and the result is drawn as filled contours on the current
    matplotlib axes. The samples are then scattered over it, coloured by
    label.

    Parameters
    ----------
    X : array_like, shape (n_samples, 2)
        Two-feature samples to draw.
    y : array_like, shape (n_samples,) or (n_samples, 1)
        Labels used to colour the points (flattened before plotting).
    weights : array_like, shape (3,)
        Perceptron weights, bias weight first.
    h : float, optional
        Grid step in data units. Default 0.02. Smaller values give a finer
        boundary but a larger grid.
    padding : float, optional
        Margin added around the data range on both axes. Default 0.1.

    Raises
    ------
    ValueError
        If ``X`` is not two-column or ``h`` is not positive.
    """
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError("plot expects X with exactly two columns, got shape {}".format(X.shape))
    if h <= 0:
        raise ValueError("h must be positive, got {}".format(h))
    # Grid spanning [x_min, x_max] x [y_min, y_max], i.e. the data plus padding.
    x_min, x_max = X[:, 0].min() - padding, X[:, 0].max() + padding
    y_min, y_max = X[:, 1].min() - padding, X[:, 1].max() + padding
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    # Classify every grid point; the filled contour shows the decision regions.
    Z = predict(np.c_[xx.ravel(), yy.ravel()], weights).reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
    # Overlay the training points, coloured by their labels.
    plt.scatter(X[:, 0], X[:, 1], c=np.ravel(y), cmap=plt.cm.Paired)
    plt.title("Linear Perceptron")
In [ ]:
# Toy problem: the logical OR truth table (output 0 only for input (0, 0)).
X = array([[0.0, 0.0],
           [0.0, 1.0],
           [1.0, 0.0],
           [1.0, 1.0]])
y = array([0, 1, 1, 1]).reshape(-1, 1)  # targets as a (4, 1) column vector
weights = fit(X=X, y=y)
plot(X=X, y=y, weights=weights)
In [ ]:
# Second, larger training set: X1 holds 100 two-feature samples, one per row,
# and y1 holds their 0/1 labels as a (100, 1) column vector, one label per
# row of X1.
X1 = array([
[ -9.23227099e-01, -5.30005243e-01],
[ -1.73110085e-01, -6.46376050e-01],
[ -1.10962064e+00, 8.96151705e-01],
[ 1.33074731e+00, 9.16141463e-01],
[ 1.14573765e+00, 9.62953314e-01],
[ -7.42716676e-01, 6.93484785e-01],
[ 1.04554489e-03, 5.50906976e-01],
[ -3.42458183e-01, 1.02460594e+00],
[ -4.95877208e-02, -7.61725992e-01],
[ 4.37219633e-01, -1.27416648e+00],
[ 8.22040918e-02, -9.55345728e-01],
[ -1.38157686e+00, -1.31457659e+00],
[ -4.44716175e-01, -8.17372537e-01],
[ 1.48640152e+00, 1.57539434e+00],
[ -4.60192520e-01, -9.20862390e-01],
[ -1.15206706e+00, 1.25332205e+00],
[ -7.92688078e-01, 1.00870860e+00],
[ 1.41125886e+00, 1.41651762e+00],
[ 7.39141162e-01, -5.42935168e-01],
[ 8.12223936e-01, 7.08731757e-01],
[ -1.05594115e+00, 1.02072293e+00],
[ -1.01032284e-01, 7.99648598e-01],
[ 5.01974839e-01, 9.85599998e-01],
[ 2.57541673e-01, 1.59016619e+00],
[ 1.03164812e-02, -8.82285095e-01],
[ 4.08720815e-01, 1.21129444e+00],
[ 1.28439088e+00, 7.71001466e-01],
[ 5.93453609e-01, -1.26871190e+00],
[ 9.76173662e-01, 1.04143459e+00],
[ -1.17353929e+00, -1.43257237e+00],
[ -3.77976108e-01, 1.03985375e+00],
[ -1.48407759e+00, -1.93939436e+00],
[ 2.28896742e-01, -1.17375324e+00],
[ 1.52391855e-01, 1.32495971e+00],
[ 6.70220444e-01, -8.25508296e-01],
[ -2.08731949e-01, 1.06803342e+00],
[ 7.57544238e-01, -9.82383636e-01],
[ 1.65190811e+00, 8.53643762e-01],
[ -1.83938473e-01, -8.44148426e-01],
[ 8.25142839e-01, -1.27206664e+00],
[ -4.75204996e-01, -5.76071616e-01],
[ 3.28622643e-02, -8.21111196e-01],
[ -8.54068536e-01, 8.36181257e-01],
[ 4.04572332e-01, 6.46307932e-01],
[ 3.10424573e-01, 1.51448476e+00],
[ -6.01646368e-01, 7.58311818e-01],
[ 1.74946020e+00, -3.50220782e-01],
[ -3.95594985e-01, 1.24822740e+00],
[ -1.26467028e+00, -1.16468449e+00],
[ -1.22973181e+00, 3.81225065e-02],
[ 2.71793375e+00, -3.10921474e-01],
[ 8.97826883e-01, 9.70116705e-01],
[ -6.07984572e-01, -1.69795597e+00],
[ 1.47547972e+00, 1.06906869e+00],
[ 1.01987159e+00, 8.40878783e-01],
[ -1.54322623e+00, -1.46571861e+00],
[ 8.94530888e-01, 8.81244933e-01],
[ -7.45225387e-01, 1.08521463e+00],
[ 1.58324206e+00, 1.09298561e+00],
[ -1.56754521e+00, 8.88357627e-01],
[ 5.70834866e-01, 1.01548812e+00],
[ 1.15399623e+00, 1.14337021e+00],
[ -3.13148537e-01, 8.87938007e-01],
[ 1.55084825e+00, -1.27644000e+00],
[ -1.44965500e+00, -3.41656484e-01],
[ 1.30931321e-01, 1.16975619e+00],
[ -2.28886029e-02, 3.36663384e-01],
[ -9.71023636e-02, -5.82380543e-01],
[ -1.47254009e+00, -8.83793350e-01],
[ 1.20271765e+00, -4.02172987e-01],
[ 1.52539471e+00, -9.20498692e-01],
[ 8.35164350e-01, 1.15717157e+00],
[ -4.58132918e-01, -1.00696667e+00],
[ -8.81610141e-01, 1.20409997e+00],
[ -5.40723050e-01, -1.27656171e+00],
[ 1.22458678e+00, -3.91674583e-01],
[ -8.30445734e-01, 9.57360694e-01],
[ 1.91171975e+00, -5.62779046e-01],
[ 8.07527728e-05, 1.04153188e+00],
[ 9.74069206e-01, 1.34466079e+00],
[ -5.12342484e-02, -6.86384766e-01],
[ -4.40147702e-01, -7.49113477e-01],
[ -5.52724936e-01, -2.42377689e+00],
[ 7.87693426e-01, -2.05547254e-01],
[ 6.34709108e-01, 1.17812988e+00],
[ -1.52990783e+00, -1.48431388e+00],
[ 8.58868905e-01, 1.38018801e+00],
[ -3.99853412e-01, -1.15877875e+00],
[ -8.52843901e-01, -6.52141259e-01],
[ 1.78260416e+00, 1.04114712e+00],
[ -4.68116118e-01, -1.26957986e+00],
[ -1.26971534e+00, -1.12313560e+00],
[ -7.86969305e-01, -6.20754582e-01],
[ -3.04161075e-01, -6.89139271e-02],
[ 1.68796351e+00, 9.71052012e-01],
[ -6.30695625e-01, -7.32753883e-01],
[ -8.21548570e-01, 1.20872775e+00],
[ 1.35440574e+00, -1.24659940e+00],
[ 7.52913742e-01, 6.90800667e-01],
[ 1.24474309e+00, -3.14104974e-01]])
# Labels for X1 in the same row order, each wrapped in its own list so the
# result is a column vector.
y1 = array([[1], [1], [0], [0], [0], [0], [0], [0], [1], [1], [1], [1],
[1], [0], [1], [0], [0], [0], [1], [0], [0], [0], [0], [0], [1], [0], [0],
[1], [0], [1], [0], [1], [1], [0], [1], [0], [1], [0], [1], [1], [1], [1],
[0], [0], [0], [0], [1], [0], [1], [1], [1], [0], [1], [0], [0], [1], [0],
[0], [0], [0], [0], [0], [0], [1], [1], [0], [0], [1], [1], [1], [1], [0],
[1], [0], [1], [1], [0], [1], [0], [0], [1], [1], [1], [1], [0], [1], [0],
[1], [1], [0], [1], [1], [1], [1], [0], [1], [0], [1], [0], [1]])
In [ ]:
# Train on the 100-sample set, then draw the learned decision regions.
weights = fit(X=X1, y=y1)
plot(X=X1, y=y1, weights=weights)