In [26]:
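# Enable inline plotting; %pylab also populates the namespace from numpy and matplotlib (this is where `plt` comes from)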
%pylab inline
from IPython.display import Image
import numpy as np
plt.style.use('ggplot')


Populating the interactive namespace from numpy and matplotlib

In [17]:
# Toy 2-D training set: each row is ((x, y), label) with label in {-1, +1}.
# The rows are ragged (a 2-tuple next to a scalar), so dtype=object is required:
# NumPy >= 1.24 refuses to build an array from ragged nested sequences without it.
# The result is a (10, 2) object array where dataset[i][0] is the (x, y) tuple
# and dataset[i][1] is the label.
dataset = np.array([
    ((-0.4, 0.3), -1),
    ((-0.3, -0.1), -1),
    ((-0.2, 0.4), -1),
    ((-0.1, 0.1), -1),
    ((0.9, -0.5), 1),
    ((0.7, -0.9), 1),
    ((0.8, 0.2), 1),
    ((0.4, -0.6), 1),
    ((0.2, -0.2), 1),
    ((0.2, 0.5), 1)], dtype=object)

In [18]:
# Scatter the raw dataset, one colour per class.
# x0/y0 hold the label -1 points and x1/y1 the label +1 points; they are kept
# as module-level lists because plot() redraws them.
negatives = [row[0] for row in dataset if row[1] == -1]
positives = [row[0] for row in dataset if row[1] == 1]
x0 = [pt[0] for pt in negatives]
y0 = [pt[1] for pt in negatives]
x1 = [pt[0] for pt in positives]
y1 = [pt[1] for pt in positives]

fig = plt.figure()
plt.scatter(x0, y0, color="DarkGreen")
plt.scatter(x1, y1, color="DarkBlue")


Out[18]:
<matplotlib.collections.PathCollection at 0x1069ceb10>

In [27]:
# Reference figure for the Perceptron Learning Algorithm.
# The path used to be absolute ('/Users/wy/Desktop/PLA.png'), which exists only
# on the author's machine; keep PLA.png next to the notebook instead.
PLA_IMAGE_PATH = 'PLA.png'
Image(filename=PLA_IMAGE_PATH)


Out[27]:

In [19]:
def sign(w1,w2,point):
    """Return the perceptron activation w . point for the weight vector w = (w1, w2).

    Despite its name, this returns the raw dot product rather than +1/-1;
    the caller decides the class by comparing the result against 0.

    Parameters
    ----------
    w1, w2 : float
        Components of the weight vector.
    point : sequence of two floats
        The (x, y) coordinates of a sample.

    Returns
    -------
    numpy scalar
        w1 * point[0] + w2 * point[1].
    """
    weights = np.array([w1, w2])
    return np.dot(weights, point)

In [20]:
def plot(w1,w2):
    """Draw both classes and the decision boundary w1*x + w2*y = 0 in a new figure.

    Reads the module-level lists x0/y0 (label -1 points) and x1/y1
    (label +1 points) built by the scatter cell.

    Parameters
    ----------
    w1, w2 : float
        Components of the weight vector w. The model has no bias term, so the
        boundary always passes through the origin.
    """
    plt.figure()
    plt.scatter(x0, y0, color="DarkGreen")
    plt.scatter(x1, y1, color="DarkBlue")
    # Sample the boundary over [-1, 1], including both endpoints.
    span = np.linspace(-1.0, 1.0, 21)
    if w1 != 0:
        # Solve w1*x + w2*y = 0 for x at each sampled y (slope computed once).
        x = -(w2 / w1) * span
        y = span
    elif w2 != 0:
        # w1 == 0: the boundary is w2*y = 0, i.e. the horizontal line y = 0.
        x = span
        y = np.zeros_like(span)
    else:
        # w == (0, 0) defines no boundary; show the points only.
        plt.title('w = (%g, %g): no decision boundary' % (w1, w2))
        return
    plt.plot(x, y, 'r')
    # Label each figure with its weights, since pla() draws one per update.
    plt.title('w = (%g, %g)' % (w1, w2))

In [21]:
def pla(w1, w2, max_updates=None):
    """Run the Perceptron Learning Algorithm on the module-level ``dataset``.

    Starting from the weight vector (w1, w2), scan ``dataset`` from the first
    row. At the first misclassified point, apply the update
    w <- w + label * point, redraw with plot(), and restart the scan from the
    beginning. Training stops after a full pass with no mistakes.

    Parameters
    ----------
    w1, w2 : float
        Initial weights.
    max_updates : int or None, optional
        Maximum number of weight updates to perform. ``None`` (the default)
        loops until convergence, which never ends when the data cannot be
        separated by a line through the origin.

    Returns
    -------
    tuple of float
        The final weights ``(w1, w2)``.
    """
    plot(w1, w2)
    n_updates = 0
    while True:
        for point, label in dataset:
            label = float(label)
            activation = sign(w1, w2, point)
            # A point exactly on the boundary (activation == 0) is not
            # correctly classified either; counting it as correct would let
            # e.g. an all-zero weight vector pass as a solution.
            if label * activation <= 0:
                if max_updates is not None and n_updates >= max_updates:
                    import warnings
                    warnings.warn('PLA stopped after %d updates without converging'
                                  % n_updates)
                    return w1, w2
                w1 = w1 + point[0] * label
                w2 = w2 + point[1] * label
                n_updates += 1
                plot(w1, w2)
                # Restart the scan from the first point with the new weights.
                break
        else:
            # A full pass with no mistake (or an empty dataset): converged.
            return w1, w2

In [22]:
# Train from a small, non-zero initial weight vector; pla() plots every update.
INITIAL_WEIGHTS = (0.01, 0.02)
pla(w1=INITIAL_WEIGHTS[0], w2=INITIAL_WEIGHTS[1])