In [26]:
# Start the inline plotting environment.
# NOTE(review): `%pylab inline` star-imports numpy and matplotlib.pyplot into the
# namespace — that is where `plt` below comes from. Later cells redefine pylab
# names such as `sign` and `plot`. IPython discourages %pylab; consider
# `%matplotlib inline` plus an explicit `import matplotlib.pyplot as plt`.
%pylab inline
from IPython.display import Image
import numpy as np
# Use the ggplot look for every figure in the notebook.
plt.style.use('ggplot')
In [17]:
# Toy 2-D dataset: each row is ((x, y), label) with label -1 or +1.
# dtype=object is required: each row mixes a 2-tuple with a scalar, and
# NumPy >= 1.24 refuses to build such a ragged array implicitly (ValueError).
# The result is a (10, 2) object array: dataset[i][0] is the point tuple,
# dataset[i][1] is its label.
dataset = np.array([
    ((-0.4, 0.3), -1),
    ((-0.3, -0.1), -1),
    ((-0.2, 0.4), -1),
    ((-0.1, 0.1), -1),
    ((0.9, -0.5), 1),
    ((0.7, -0.9), 1),
    ((0.8, 0.2), 1),
    ((0.4, -0.6), 1),
    ((0.2, -0.2), 1),
    ((0.2, 0.5), 1)], dtype=object)
In [18]:
# Split the points by label into coordinate lists.
# x0/y0 hold the -1 class and x1/y1 the +1 class; plot() below reuses these globals.
negatives = [row[0] for row in dataset if row[1] == -1]
positives = [row[0] for row in dataset if row[1] == 1]
x0 = [p[0] for p in negatives]
y0 = [p[1] for p in negatives]
x1 = [p[0] for p in positives]
y1 = [p[1] for p in positives]

fig = plt.figure()
plt.scatter(x0, y0, color="DarkGreen")
plt.scatter(x1, y1, color="DarkBlue")
Out[18]:
In [27]:
# Illustration of the PLA update rule.
# TODO: this absolute path only exists on the author's machine; move the image
# next to the notebook and use a relative path.
pla_figure_path = '/Users/wy/Desktop/PLA.png'
Image(filename=pla_figure_path)
Out[27]:
In [19]:
def sign(w1, w2, point):
    """Return the perceptron score w . point for the weight vector w = (w1, w2).

    Despite the name, this is the raw dot product, not -1/0/+1; callers
    compare it against zero. The model has no bias term.

    Parameters
    ----------
    w1, w2 : float
        Weight components.
    point : sequence of two numbers
        The (x, y) input point.
    """
    weights = np.array([w1, w2])
    return np.dot(weights, point)
In [20]:
def plot(w1, w2):
    """Scatter both classes and draw the decision boundary w1*x + w2*y = 0.

    Uses the module-level coordinate lists x0/y0 (label -1, green) and
    x1/y1 (label +1, blue). The boundary is drawn in red and always passes
    through the origin, because the model has no bias term.

    Parameters
    ----------
    w1, w2 : float
        Current perceptron weights. If both are zero there is no boundary,
        and only the points are drawn.
    """
    plt.figure()
    plt.scatter(x0, y0, color="DarkGreen")
    plt.scatter(x1, y1, color="DarkBlue")

    # Parameter values for the boundary; symmetric, including both endpoints.
    ts = np.linspace(-1.0, 1.0, 21)
    if w1 != 0:
        # Solve w1*x + w2*y = 0 for x, with y running over ts.
        slope = -(w2 / w1)
        plt.plot(ts * slope, ts, 'r')
    elif w2 != 0:
        # w1 == 0: the boundary is the horizontal line y = 0 (avoids dividing by w1).
        plt.plot(ts, np.zeros_like(ts), 'r')
    # w1 == w2 == 0: every point scores 0, so there is no line to draw.

    plt.title("w = ({:.3f}, {:.3f})".format(w1, w2))
    plt.xlabel("x")
    plt.ylabel("y")
In [21]:
def pla(w1, w2, max_updates=1000):
    """Run the Perceptron Learning Algorithm (PLA) on the global `dataset`.

    Starting from the weights (w1, w2), scan `dataset` in order. At the first
    misclassified point (x, label), apply w <- w + label * x, plot the new
    boundary, and restart the scan. Stop once a full pass finds no mistake.
    The model has no bias term, so the boundary passes through the origin.

    Parameters
    ----------
    w1, w2 : float
        Initial weights.
    max_updates : int or None, optional
        Maximum number of weight updates. PLA never halts on data that is not
        linearly separable through the origin, so this cap prevents an endless
        loop (and endless figures). Pass ``None`` for no cap.

    Returns
    -------
    tuple
        The final weights ``(w1, w2)``.
    """
    import warnings

    plot(w1, w2)
    updates = 0
    while True:
        # Find the first misclassified point in dataset order.
        mistake = None
        for point, label in dataset:
            score = sign(w1, w2, point)
            # A point exactly on the boundary (score == 0) is not correctly
            # classified either; points with label 0 are ignored, as before.
            if (label > 0 and score <= 0) or (label < 0 and score >= 0):
                mistake = (point, label)
                break

        if mistake is None:
            # Converged: a full pass with no mistakes (also covers an empty dataset).
            return w1, w2

        if max_updates is not None and updates >= max_updates:
            warnings.warn(
                "pla: stopped after {} updates without converging; the data may "
                "not be linearly separable through the origin.".format(updates))
            return w1, w2

        (px, py), label = mistake
        w1 = w1 + px * float(label)
        w2 = w2 + py * float(label)
        updates += 1
        plot(w1, w2)
In [22]:
# Train from a small non-zero starting weight vector; every update is plotted.
initial_w1, initial_w2 = 0.01, 0.02
pla(w1=initial_w1, w2=initial_w2)