使用寶可夢作為實驗資料素材


In [4]:
import numpy as np

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

df = pd.read_csv("./pokemon.csv")

In [3]:



Out[3]:
name species cp hp weight height power_up_stardust power_up_candy attack_weak attack_weak_type ... height_new power_up_stardust_new power_up_candy_new attack_weak_new attack_weak_type_new attack_weak_value_new attack_strong_new attack_strong_type_new attack_strong_value_new notes
0 Pidgey1 Pidgey 384 56 2.31 0.34 2500 2 Tackle Normal ... 1.24 2500 2 Steel Wing Steel 15 Air Cutter Flying 30 NaN
1 Pidgey2 Pidgey 366 54 1.67 0.29 2500 2 Quick Attack Normal ... 1.05 2500 2 Wing Attack Flying 9 Air Cutter Flying 30 NaN
2 Pidgey3 Pidgey 353 55 1.94 0.30 3000 3 Quick Attack Normal ... 1.11 3000 3 Wing Attack Flying 9 Air Cutter Flying 30 NaN
3 Pidgey4 Pidgey 338 51 1.73 0.31 3000 3 Tackle Normal ... 1.12 3000 3 Steel Wing Steel 15 Air Cutter Flying 30 NaN
4 Pidgey5 Pidgey 242 45 1.44 0.27 1900 2 Quick Attack Normal ... 0.98 1900 2 Wing Attack Flying 9 Twister Dragon 25 NaN
5 Pidgey6 Pidgey 129 35 2.07 0.35 800 1 Quick Attack Normal ... 1.27 800 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
6 Pidgey7 Pidgey 10 10 0.92 0.25 200 1 Tackle Normal ... 0.90 200 1 Wing Attack Flying 9 Air Cutter Flying 30 NaN
7 Pidgey8 Pidgey 25 14 2.72 0.37 200 1 Tackle Normal ... 1.35 200 1 Steel Wing Steel 15 Air Cutter Flying 30 NaN
8 Pidgey9 Pidgey 24 13 2.07 0.32 200 1 Quick Attack Normal ... 1.16 200 1 Wing Attack Flying 9 Twister Dragon 25 NaN
9 Pidgey10 Pidgey 161 35 1.45 0.31 1000 1 Tackle Normal ... 1.14 1000 1 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
10 Pidgey11 Pidgey 114 31 1.58 0.26 800 1 Tackle Normal ... 0.96 800 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
11 Pidgey12 Pidgey 333 52 1.85 0.30 3000 3 Tackle Normal ... 1.10 3000 3 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
12 Pidgey13 Pidgey 132 33 1.63 0.28 800 1 Quick Attack Normal ... 1.03 800 1 Wing Attack Flying 9 Air Cutter Flying 30 NaN
13 Pidgey14 Pidgey 60 21 1.67 0.30 400 1 Quick Attack Normal ... 1.09 400 1 Steel Wing Steel 15 Twister Dragon 25 NaN
14 Pidgey15 Pidgey 42 19 2.01 0.30 400 1 Quick Attack Normal ... 1.11 400 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
15 Pidgey16 Pidgey 91 29 2.68 0.35 600 1 Quick Attack Normal ... 1.30 600 1 Steel Wing Steel 15 Air Cutter Flying 30 NaN
16 Pidgey17 Pidgey 139 34 1.76 0.31 1000 1 Tackle Normal ... 1.13 1000 1 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
17 Pidgey18 Pidgey 330 48 1.62 0.29 2500 2 Quick Attack Normal ... 1.08 2500 2 Wing Attack Flying 9 Aerial Ace Flying 30 weight_new is accurate
18 Pidgey19 Pidgey 328 48 1.62 0.30 2500 2 Quick Attack Normal ... 1.08 2500 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
19 Pidgey20 Pidgey 176 35 1.81 0.29 1000 1 Tackle Normal ... 1.07 1000 1 Steel Wing Steel 15 Twister Dragon 25 NaN
20 Pidgey21 Pidgey 97 30 1.95 0.31 600 1 Tackle Normal ... 1.15 600 1 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
21 Pidgey22 Pidgey 74 24 0.78 0.26 600 1 Tackle Normal ... 0.95 600 1 Steel Wing Steel 15 Twister Dragon 25 NaN
22 Pidgey23 Pidgey 127 32 2.33 0.34 800 1 Tackle Normal ... 1.23 800 1 Steel Wing Steel 15 Air Cutter Flying 30 NaN
23 Pidgey24 Pidgey 78 26 2.11 0.32 600 1 Tackle Normal ... 1.17 600 1 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
24 Pidgey25 Pidgey 240 41 2.53 0.33 1600 2 Tackle Normal ... 1.21 1600 2 Wing Attack Flying 9 Air Cutter Flying 30 NaN
25 Pidgey26 Pidgey 276 48 0.87 0.24 2200 2 Tackle Normal ... 0.86 2200 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
26 Pidgey27 Pidgey 207 41 1.19 0.26 1600 2 Quick Attack Normal ... 0.96 1600 2 Steel Wing Steel 15 Air Cutter Flying 30 NaN
27 Pidgey28 Pidgey 176 35 1.68 0.31 1300 2 Tackle Normal ... 1.13 1300 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
28 Pidgey29 Pidgey 316 47 1.71 0.27 2500 2 Quick Attack Normal ... 0.99 2500 2 Wing Attack Flying 9 Twister Dragon 25 NaN
29 Pidgey30 Pidgey 305 53 2.04 0.32 2200 2 Quick Attack Normal ... 1.17 2200 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
45 Weedle7 Weedle 45 27 2.55 0.29 600 1 Bug Bite Bug ... 0.58 600 1 Bug Bite Bug 5 Struggle Normal 15 NaN
46 Weedle8 Weedle 129 45 4.01 0.34 1600 2 Bug Bite Bug ... 0.69 1600 2 Bug Bite Bug 5 Struggle Normal 15 NaN
47 Weedle9 Weedle 200 49 2.62 0.27 3000 3 Poison Sting Poison ... 0.55 3000 3 Bug Bite Bug 5 Struggle Normal 15 NaN
48 Weedle10 Weedle 179 47 3.95 0.35 2500 2 Bug Bite Bug ... 0.70 2500 2 Poison Sting Poison 6 Struggle Normal 15 NaN
49 Weedle11 Weedle 104 36 3.48 0.31 1300 2 Poison Sting Poison ... 0.61 1300 2 Poison Sting Poison 6 Struggle Normal 15 NaN
50 Weedle12 Weedle 144 48 2.60 0.29 1900 2 Poison Sting Poison ... 0.57 1900 2 Bug Bite Bug 5 Struggle Normal 15 NaN
51 Weedle13 Weedle 157 49 2.05 0.28 2200 2 Bug Bite Bug ... 0.56 2200 2 Poison Sting Poison 6 Struggle Normal 15 NaN
52 Weedle14 Weedle 106 39 5.31 0.39 1300 2 Bug Bite Bug ... 0.77 1300 2 Poison Sting Poison 6 Struggle Normal 15 NaN
53 Weedle15 Weedle 203 52 3.54 0.30 3000 3 Bug Bite Bug ... 0.59 3000 3 Poison Sting Poison 6 Struggle Normal 15 NaN
54 Weedle16 Weedle 16 13 2.68 0.26 200 1 Bug Bite Bug ... 0.52 200 1 Bug Bite Bug 5 Struggle Normal 15 NaN
55 Weedle17 Weedle 240 54 4.33 0.35 2500 2 Poison Sting Poison ... 0.71 2500 2 Bug Bite Bug 5 Struggle Normal 15 NaN
56 Weedle18 Weedle 60 27 3.46 0.32 800 1 Poison Sting Poison ... 0.63 800 1 Bug Bite Bug 5 Struggle Normal 15 NaN
57 Weedle19 Weedle 80 35 1.69 0.22 800 1 Poison Sting Poison ... 0.44 800 1 Poison Sting Poison 6 Struggle Normal 15 NaN
58 Weedle20 Weedle 75 30 3.52 0.35 800 1 Bug Bite Bug ... 0.70 800 1 Bug Bite Bug 5 Struggle Normal 15 NaN
59 Caterpie1 Caterpie 136 44 4.68 0.35 1600 2 Bug Bite Bug ... 0.82 1600 2 Tackle Normal 12 Struggle Normal 15 NaN
60 Caterpie2 Caterpie 231 62 3.80 0.32 2500 2 Bug Bite Bug ... 0.75 2500 2 Tackle Normal 12 Struggle Normal 15 NaN
61 Caterpie3 Caterpie 43 26 2.68 0.28 600 1 Bug Bite Bug ... 0.65 600 1 Tackle Normal 12 Struggle Normal 15 NaN
62 Caterpie4 Caterpie 140 49 2.70 0.30 1900 2 Tackle Normal ... 0.71 1900 2 Bug Bite Bug 5 Struggle Normal 15 NaN
63 Caterpie5 Caterpie 208 53 2.24 0.29 2500 2 Tackle Normal ... 0.67 2500 2 Tackle Normal 12 Struggle Normal 15 NaN
64 Caterpie6 Caterpie 128 46 1.69 0.24 1600 2 Tackle Normal ... 0.57 1600 2 Bug Bite Bug 5 Struggle Normal 15 NaN
65 Caterpie7 Caterpie 186 56 3.14 0.27 3000 3 Tackle Normal ... 0.64 3000 3 Tackle Normal 12 Struggle Normal 15 NaN
66 Caterpie8 Caterpie 149 46 2.69 0.31 1900 2 Bug Bite Bug ... 0.72 1900 2 Tackle Normal 12 Struggle Normal 15 NaN
67 Caterpie9 Caterpie 10 10 1.87 0.29 200 1 Bug Bite Bug ... 0.69 200 1 Tackle Normal 12 Struggle Normal 15 NaN
68 Caterpie10 Caterpie 86 10 3.50 0.30 1000 1 Tackle Normal ... 0.70 1000 1 Bug Bite Bug 5 Struggle Normal 15 NaN
69 Eevee1 Eevee 619 74 2.87 0.20 3000 3 Quick Attack Normal ... 0.68 3000 3 Water Gun Water 6 Aqua Tail Water 45 Vaporean, attack details added on date later t...
70 Eevee3 Eevee 500 64 5.11 0.26 2200 2 Tackle Normal ... 0.77 2200 2 Ember Fire 10 Fire Blast Fire 100 Jolteon, attack details added on date later th...
71 Eevee7 Eevee 606 74 10.42 0.38 2500 2 Tackle Normal ... 1.26 2500 2 Water Gun Water 6 Aqua Tail Water 45 Vaporeon, attack details added on date later t...
72 Eevee8 Eevee 548 66 6.87 0.32 2500 2 Tackle Normal ... 0.95 2500 2 Ember Fire 10 Heat Wave Fire 80 Flareon, attack details added on date later th...
73 Eevee9 Eevee 528 66 7.43 0.31 2200 2 Quick Attack Normal ... 0.82 2200 2 Thunder Shock Electric 5 Thunderbolt Electric 55 Jolteon, attack details added on date later th...
74 Eevee10 Eevee 517 72 6.63 0.29 2500 2 Tackle Normal ... 0.88 2500 2 Ember Fire 10 Heat Wave Fire 80 Flareon, attack details added on date later th...

75 rows × 27 columns


In [5]:
alldata = df[ (df['attack_strong_type'] == 'Normal') |(df['attack_strong_type'] == 'Flying')  ]
# alldata = alldata[df['attack_strong_type'] == 'Flying']

In [6]:
f1 = alldata['height'].tolist()
f2 = alldata['weight'].tolist()
y = alldata['attack_strong_type']=='Normal'
y = [ 1 if i  else 0  for i in y.tolist()]


f1 = np.array(f1)
f2 = np.array(f2)

In [7]:
c = [ 'g' if i==1 else 'b' for i in y ]

plt.scatter(f1, f2, 20, c=c, alpha=0.5,
            label="Type")
plt.xlabel("Height")
plt.ylabel("Weight")
plt.legend(loc=2)
plt.show()


對原始資料進行 Sacle to 0~1 之間,避免 weight 太劇烈變動,導至 learning rate 難以設定


In [8]:
from sklearn.preprocessing import scale,MinMaxScaler
scaler1 = MinMaxScaler()
scaler2 = MinMaxScaler()

f1 = f2.reshape([f1.shape[0],1])
f2 = f2.reshape([f2.shape[0],1])



scaler1.fit(f1)
scaler2.fit(f2)

f1 = scaler1.transform(f1)
f2 = scaler2.transform(f2)

In [9]:
f1 = f1.reshape(f1.shape[0])
f2 = f2.reshape(f2.shape[0])

In [10]:
c = [ 'g' if i==1 else 'b' for i in y ]

plt.scatter(f1, f2, 20, c=c, alpha=0.5,
            label="Type")
plt.xlabel("Height")
plt.ylabel("Weight")
plt.legend(loc=2)
plt.show()



In [11]:
Y = np.array([1,1,0,0,1])
A = np.array([0.8, 0.7, 0.2, 0.1, 0.9])
A2 = np.array([0.6, 0.6, 0.2, 0.1, 0.3])

def cross_entropy(Y,A):
    
    # small tip 因 log(0) 會趨近負無限大,會產生 nan ,故這裡統一加上 0.00001
    Y = np.array(Y)
    A = np.array(A)
    m = len(A)
    cost = -(1.0/m) * np.sum(Y*np.log(A+0.00001) + (1-Y)*np.log(1-A+0.00001))
    return cost
# Test cross_entropy Function
print cross_entropy(Y,A)
print cross_entropy(Y,A2)
print cross_entropy(Y,Y)


  File "<ipython-input-11-932cbe4c2daa>", line 14
    print cross_entropy(Y,A)
                      ^
SyntaxError: invalid syntax

LogisticRegression 公式推導如下

推導過程請參考 : http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2017/Lecture/Logistic%20Regression%20(v4).pdf 中的 3~13 頁,其重點步驟如下

  • 令 $y=sigmoid(w*x+b)=f_{w,b}$
  • 我們想要 Maxima右邊這個式之 => $ ArgMaxL_{w,b}= \prod\left( \hat y*f_{w,b} + (1-\hat y)*(1-f_{w,b}) \right) $
  • 但如果我們要看成 Lose Function 反過加上個負號並取 Log 方便計算,所以式之會變成如下
  • $ ArgMinL_{w,b}= -1*ln\sum\left( \hat y*f_{w,b} + (1-\hat y)*(1-f_{w,b}) \right) $
  • 然後對述的式子,分別對 $w,b$ 作偏微分,以更取得到 $\Delta w$及$\Delta b$
  • 然後再乘上 Learning Rate 進行更新如式 $w_{t+1}=w_t - r*\Delta w $ , $b_{t+1}=b_t - r*\Delta b $

最終推導結果如下

  • $w_i$ 的 update 公式如下 , $\hat y$ 為 traning data 的 target label, $x^n$ 為第 n 個 data 的值
  • $w_{t+1} = w_t - r*\sum\left((\hat y^n - f_{w,b}(x^n))*x^n*-1\right) $
  • $b_i$ 的 update 與 $w_i$ 只差了一項就是不用乘上 $X^n$ ,如右 $b_{t+1} = b_t - r*\sum\left(-1*(\hat y^n - f_{w,b}(x^n))\right) $

In [16]:
import math

w1 =1
w2 =1
b = 0
r = 0.001
def fx(x1,x2):
    temp = w1*x1 + w2*x2 + b
    y_head = 1. / (1. + math.exp(-1.*temp))
    return y_head

In [17]:
def cross_entropy(Y, A):
    m = len(A)
    cost = -(1.0 / m) * np.sum(Y * np.log(A) + (1 - Y) * np.log(1 - A))
    return cost

for i in range(10000):
    w1_delta=0
    w2_delta=0

    b_delta = 0

    y_error = 0 
    
    
    for x1,x2,y_now in zip(f1,f2,y):
        
        y_error = y_now - fx(x1,x2)
        

        w1_delta = -1*x1*y_error
        w2_delta = -1*x2*y_error

        b_delta = -1*y_error

        w1 -= r*w1_delta
        w2 -= r*w2_delta

        b -= r*b_delta



    
    if i % 100==0 : 
        error_rate = 0
        y_predict = []
        
        for x1,x2,y_now in zip(f1,f2,y):            
            y_predict.append(fx(x1,x2))
            if y_now==1 and fx(x1,x2) < 0.5:
                error_rate+=1
            elif y_now==0 and fx(x1,x2) >=0.5:
                error_rate+=1

        print("{:0,.3f}, {:0,.3f}, {:0,.3f}, {:0,.3f}, {:0,.3f}".format(error_rate*1./len(y) ,cross_entropy(np.array(y),np.array(y_predict)),w1,w2,b)  )


0.397, 0.598, 1.002, 1.002, -0.001
0.379, 0.586, 1.184, 1.184, -0.070
0.345, 0.575, 1.357, 1.357, -0.138
0.328, 0.565, 1.522, 1.522, -0.203
0.310, 0.555, 1.681, 1.681, -0.264
0.293, 0.547, 1.834, 1.834, -0.322
0.259, 0.539, 1.982, 1.982, -0.378
0.241, 0.531, 2.125, 2.125, -0.431
0.241, 0.524, 2.263, 2.263, -0.481
0.241, 0.517, 2.398, 2.398, -0.530
0.224, 0.511, 2.528, 2.528, -0.577
0.207, 0.505, 2.654, 2.654, -0.623
0.190, 0.500, 2.777, 2.777, -0.667
0.190, 0.495, 2.897, 2.897, -0.709
0.190, 0.490, 3.014, 3.014, -0.750
0.172, 0.485, 3.128, 3.128, -0.790
0.172, 0.480, 3.239, 3.239, -0.828
0.172, 0.476, 3.348, 3.348, -0.866
0.155, 0.472, 3.454, 3.454, -0.903
0.155, 0.468, 3.558, 3.558, -0.938
0.138, 0.464, 3.659, 3.659, -0.973
0.138, 0.461, 3.759, 3.759, -1.007
0.138, 0.457, 3.856, 3.856, -1.040
0.138, 0.454, 3.952, 3.952, -1.072
0.138, 0.451, 4.045, 4.045, -1.103
0.138, 0.448, 4.137, 4.137, -1.134
0.138, 0.445, 4.227, 4.227, -1.164
0.138, 0.442, 4.316, 4.316, -1.194
0.138, 0.439, 4.402, 4.402, -1.222
0.138, 0.436, 4.488, 4.488, -1.251
0.138, 0.434, 4.571, 4.571, -1.278
0.138, 0.431, 4.654, 4.654, -1.305
0.138, 0.429, 4.735, 4.735, -1.332
0.138, 0.427, 4.815, 4.815, -1.358
0.138, 0.424, 4.893, 4.893, -1.384
0.138, 0.422, 4.970, 4.970, -1.409
0.138, 0.420, 5.046, 5.046, -1.434
0.138, 0.418, 5.121, 5.121, -1.458
0.138, 0.416, 5.194, 5.194, -1.482
0.138, 0.414, 5.267, 5.267, -1.505
0.138, 0.412, 5.338, 5.338, -1.528
0.138, 0.411, 5.409, 5.409, -1.551
0.155, 0.409, 5.478, 5.478, -1.573
0.155, 0.407, 5.547, 5.547, -1.595
0.155, 0.406, 5.614, 5.614, -1.617
0.155, 0.404, 5.681, 5.681, -1.638
0.155, 0.402, 5.746, 5.746, -1.659
0.155, 0.401, 5.811, 5.811, -1.680
0.155, 0.399, 5.875, 5.875, -1.700
0.155, 0.398, 5.938, 5.938, -1.720
0.155, 0.397, 6.000, 6.000, -1.740
0.155, 0.395, 6.062, 6.062, -1.760
0.155, 0.394, 6.123, 6.123, -1.779
0.155, 0.393, 6.182, 6.182, -1.798
0.155, 0.391, 6.242, 6.242, -1.816
0.155, 0.390, 6.300, 6.300, -1.835
0.155, 0.389, 6.358, 6.358, -1.853
0.155, 0.388, 6.415, 6.415, -1.871
0.155, 0.386, 6.472, 6.472, -1.889
0.155, 0.385, 6.527, 6.527, -1.906
0.155, 0.384, 6.583, 6.583, -1.924
0.155, 0.383, 6.637, 6.637, -1.941
0.155, 0.382, 6.691, 6.691, -1.958
0.155, 0.381, 6.744, 6.744, -1.974
0.155, 0.380, 6.797, 6.797, -1.991
0.155, 0.379, 6.849, 6.849, -2.007
0.155, 0.378, 6.901, 6.901, -2.023
0.155, 0.377, 6.952, 6.952, -2.039
0.155, 0.376, 7.002, 7.002, -2.055
0.155, 0.375, 7.052, 7.052, -2.070
0.155, 0.374, 7.102, 7.102, -2.086
0.155, 0.374, 7.151, 7.151, -2.101
0.155, 0.373, 7.199, 7.199, -2.116
0.155, 0.372, 7.247, 7.247, -2.131
0.155, 0.371, 7.295, 7.295, -2.145
0.155, 0.370, 7.342, 7.342, -2.160
0.155, 0.370, 7.388, 7.388, -2.174
0.155, 0.369, 7.434, 7.434, -2.188
0.155, 0.368, 7.480, 7.480, -2.203
0.155, 0.367, 7.525, 7.525, -2.216
0.155, 0.367, 7.570, 7.570, -2.230
0.155, 0.366, 7.615, 7.615, -2.244
0.155, 0.365, 7.658, 7.658, -2.257
0.155, 0.364, 7.702, 7.702, -2.271
0.155, 0.364, 7.745, 7.745, -2.284
0.155, 0.363, 7.788, 7.788, -2.297
0.155, 0.362, 7.830, 7.830, -2.310
0.155, 0.362, 7.872, 7.872, -2.323
0.155, 0.361, 7.914, 7.914, -2.336
0.155, 0.361, 7.955, 7.955, -2.348
0.155, 0.360, 7.996, 7.996, -2.361
0.155, 0.359, 8.037, 8.037, -2.373
0.155, 0.359, 8.077, 8.077, -2.386
0.155, 0.358, 8.117, 8.117, -2.398
0.155, 0.358, 8.156, 8.156, -2.410
0.155, 0.357, 8.195, 8.195, -2.422
0.155, 0.357, 8.234, 8.234, -2.434
0.155, 0.356, 8.273, 8.273, -2.445
0.155, 0.356, 8.311, 8.311, -2.457
0.155, 0.355, 8.349, 8.349, -2.469