我們使用 Gaussian Density Function 來預測 Pokemon 的類型


In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

使用 scipy 來計算 gaussian density 的範例


In [2]:
from scipy.stats import multivariate_normal
x = np.linspace(0, 5, 10, endpoint=False)
y = multivariate_normal.pdf(x, mean=2.5, cov=0.5)

In [5]:
plt.plot(x, y)
plt.show()


我們來載入資料


In [6]:
df= pd.read_csv("./pokemon.csv")

In [7]:
df.columns


Out[7]:
Index([u'name', u'species', u'cp', u'hp', u'weight', u'height',
       u'power_up_stardust', u'power_up_candy', u'attack_weak',
       u'attack_weak_type', u'attack_weak_value', u'attack_strong',
       u'attack_strong_type', u'attack_strong_value', u'cp_new', u'hp_new',
       u'weight_new', u'height_new', u'power_up_stardust_new',
       u'power_up_candy_new', u'attack_weak_new', u'attack_weak_type_new',
       u'attack_weak_value_new', u'attack_strong_new',
       u'attack_strong_type_new', u'attack_strong_value_new', u'notes'],
      dtype='object')

In [8]:
pknormal = df[df['attack_strong_type'] == 'Normal']

In [9]:
pkflying = df[df['attack_strong_type'] == 'Flying']

In [10]:
pknormal


Out[10]:
name species cp hp weight height power_up_stardust power_up_candy attack_weak attack_weak_type ... height_new power_up_stardust_new power_up_candy_new attack_weak_new attack_weak_type_new attack_weak_value_new attack_strong_new attack_strong_type_new attack_strong_value_new notes
39 Weedle1 Weedle 169 45 3.38 0.29 2200 2 Bug Bite Bug ... 0.58 2200 2 Poison Sting Poison 6 Struggle Normal 15 NaN
40 Weedle2 Weedle 15 15 3.27 0.28 200 1 Bug Bite Bug ... 0.55 200 1 Poison Sting Poison 6 Struggle Normal 15 NaN
41 Weedle3 Weedle 25 17 3.11 0.29 400 1 Poison Sting Poison ... 0.58 400 1 Bug Bite Bug 5 Struggle Normal 15 NaN
42 Weedle4 Weedle 222 48 2.93 0.29 2500 2 Poison Sting Poison ... 0.59 2500 2 Bug Bite Bug 5 Struggle Normal 15 NaN
43 Weedle5 Weedle 251 56 4.51 0.34 2500 2 Poison Sting Poison ... 0.67 2500 2 Poison Sting Poison 6 Struggle Normal 15 NaN
44 Weedle6 Weedle 139 46 4.50 0.32 1900 2 Poison Sting Poison ... 0.64 1900 2 Poison Sting Poison 6 Struggle Normal 15 NaN
45 Weedle7 Weedle 45 27 2.55 0.29 600 1 Bug Bite Bug ... 0.58 600 1 Bug Bite Bug 5 Struggle Normal 15 NaN
46 Weedle8 Weedle 129 45 4.01 0.34 1600 2 Bug Bite Bug ... 0.69 1600 2 Bug Bite Bug 5 Struggle Normal 15 NaN
47 Weedle9 Weedle 200 49 2.62 0.27 3000 3 Poison Sting Poison ... 0.55 3000 3 Bug Bite Bug 5 Struggle Normal 15 NaN
48 Weedle10 Weedle 179 47 3.95 0.35 2500 2 Bug Bite Bug ... 0.70 2500 2 Poison Sting Poison 6 Struggle Normal 15 NaN
49 Weedle11 Weedle 104 36 3.48 0.31 1300 2 Poison Sting Poison ... 0.61 1300 2 Poison Sting Poison 6 Struggle Normal 15 NaN
50 Weedle12 Weedle 144 48 2.60 0.29 1900 2 Poison Sting Poison ... 0.57 1900 2 Bug Bite Bug 5 Struggle Normal 15 NaN
51 Weedle13 Weedle 157 49 2.05 0.28 2200 2 Bug Bite Bug ... 0.56 2200 2 Poison Sting Poison 6 Struggle Normal 15 NaN
52 Weedle14 Weedle 106 39 5.31 0.39 1300 2 Bug Bite Bug ... 0.77 1300 2 Poison Sting Poison 6 Struggle Normal 15 NaN
53 Weedle15 Weedle 203 52 3.54 0.30 3000 3 Bug Bite Bug ... 0.59 3000 3 Poison Sting Poison 6 Struggle Normal 15 NaN
54 Weedle16 Weedle 16 13 2.68 0.26 200 1 Bug Bite Bug ... 0.52 200 1 Bug Bite Bug 5 Struggle Normal 15 NaN
55 Weedle17 Weedle 240 54 4.33 0.35 2500 2 Poison Sting Poison ... 0.71 2500 2 Bug Bite Bug 5 Struggle Normal 15 NaN
56 Weedle18 Weedle 60 27 3.46 0.32 800 1 Poison Sting Poison ... 0.63 800 1 Bug Bite Bug 5 Struggle Normal 15 NaN
57 Weedle19 Weedle 80 35 1.69 0.22 800 1 Poison Sting Poison ... 0.44 800 1 Poison Sting Poison 6 Struggle Normal 15 NaN
58 Weedle20 Weedle 75 30 3.52 0.35 800 1 Bug Bite Bug ... 0.70 800 1 Bug Bite Bug 5 Struggle Normal 15 NaN
59 Caterpie1 Caterpie 136 44 4.68 0.35 1600 2 Bug Bite Bug ... 0.82 1600 2 Tackle Normal 12 Struggle Normal 15 NaN
60 Caterpie2 Caterpie 231 62 3.80 0.32 2500 2 Bug Bite Bug ... 0.75 2500 2 Tackle Normal 12 Struggle Normal 15 NaN
61 Caterpie3 Caterpie 43 26 2.68 0.28 600 1 Bug Bite Bug ... 0.65 600 1 Tackle Normal 12 Struggle Normal 15 NaN
62 Caterpie4 Caterpie 140 49 2.70 0.30 1900 2 Tackle Normal ... 0.71 1900 2 Bug Bite Bug 5 Struggle Normal 15 NaN
63 Caterpie5 Caterpie 208 53 2.24 0.29 2500 2 Tackle Normal ... 0.67 2500 2 Tackle Normal 12 Struggle Normal 15 NaN
64 Caterpie6 Caterpie 128 46 1.69 0.24 1600 2 Tackle Normal ... 0.57 1600 2 Bug Bite Bug 5 Struggle Normal 15 NaN
65 Caterpie7 Caterpie 186 56 3.14 0.27 3000 3 Tackle Normal ... 0.64 3000 3 Tackle Normal 12 Struggle Normal 15 NaN
66 Caterpie8 Caterpie 149 46 2.69 0.31 1900 2 Bug Bite Bug ... 0.72 1900 2 Tackle Normal 12 Struggle Normal 15 NaN
67 Caterpie9 Caterpie 10 10 1.87 0.29 200 1 Bug Bite Bug ... 0.69 200 1 Tackle Normal 12 Struggle Normal 15 NaN
68 Caterpie10 Caterpie 86 10 3.50 0.30 1000 1 Tackle Normal ... 0.70 1000 1 Bug Bite Bug 5 Struggle Normal 15 NaN
69 Eevee1 Eevee 619 74 2.87 0.20 3000 3 Quick Attack Normal ... 0.68 3000 3 Water Gun Water 6 Aqua Tail Water 45 Vaporean, attack details added on date later t...
71 Eevee7 Eevee 606 74 10.42 0.38 2500 2 Tackle Normal ... 1.26 2500 2 Water Gun Water 6 Aqua Tail Water 45 Vaporeon, attack details added on date later t...
72 Eevee8 Eevee 548 66 6.87 0.32 2500 2 Tackle Normal ... 0.95 2500 2 Ember Fire 10 Heat Wave Fire 80 Flareon, attack details added on date later th...
74 Eevee10 Eevee 517 72 6.63 0.29 2500 2 Tackle Normal ... 0.88 2500 2 Ember Fire 10 Heat Wave Fire 80 Flareon, attack details added on date later th...

34 rows × 27 columns


In [11]:
pkflying


Out[11]:
name species cp hp weight height power_up_stardust power_up_candy attack_weak attack_weak_type ... height_new power_up_stardust_new power_up_candy_new attack_weak_new attack_weak_type_new attack_weak_value_new attack_strong_new attack_strong_type_new attack_strong_value_new notes
0 Pidgey1 Pidgey 384 56 2.31 0.34 2500 2 Tackle Normal ... 1.24 2500 2 Steel Wing Steel 15 Air Cutter Flying 30 NaN
2 Pidgey3 Pidgey 353 55 1.94 0.30 3000 3 Quick Attack Normal ... 1.11 3000 3 Wing Attack Flying 9 Air Cutter Flying 30 NaN
3 Pidgey4 Pidgey 338 51 1.73 0.31 3000 3 Tackle Normal ... 1.12 3000 3 Steel Wing Steel 15 Air Cutter Flying 30 NaN
4 Pidgey5 Pidgey 242 45 1.44 0.27 1900 2 Quick Attack Normal ... 0.98 1900 2 Wing Attack Flying 9 Twister Dragon 25 NaN
5 Pidgey6 Pidgey 129 35 2.07 0.35 800 1 Quick Attack Normal ... 1.27 800 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
6 Pidgey7 Pidgey 10 10 0.92 0.25 200 1 Tackle Normal ... 0.90 200 1 Wing Attack Flying 9 Air Cutter Flying 30 NaN
10 Pidgey11 Pidgey 114 31 1.58 0.26 800 1 Tackle Normal ... 0.96 800 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
11 Pidgey12 Pidgey 333 52 1.85 0.30 3000 3 Tackle Normal ... 1.10 3000 3 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
12 Pidgey13 Pidgey 132 33 1.63 0.28 800 1 Quick Attack Normal ... 1.03 800 1 Wing Attack Flying 9 Air Cutter Flying 30 NaN
14 Pidgey15 Pidgey 42 19 2.01 0.30 400 1 Quick Attack Normal ... 1.11 400 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
17 Pidgey18 Pidgey 330 48 1.62 0.29 2500 2 Quick Attack Normal ... 1.08 2500 2 Wing Attack Flying 9 Aerial Ace Flying 30 weight_new is accurate
19 Pidgey20 Pidgey 176 35 1.81 0.29 1000 1 Tackle Normal ... 1.07 1000 1 Steel Wing Steel 15 Twister Dragon 25 NaN
22 Pidgey23 Pidgey 127 32 2.33 0.34 800 1 Tackle Normal ... 1.23 800 1 Steel Wing Steel 15 Air Cutter Flying 30 NaN
24 Pidgey25 Pidgey 240 41 2.53 0.33 1600 2 Tackle Normal ... 1.21 1600 2 Wing Attack Flying 9 Air Cutter Flying 30 NaN
25 Pidgey26 Pidgey 276 48 0.87 0.24 2200 2 Tackle Normal ... 0.86 2200 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
26 Pidgey27 Pidgey 207 41 1.19 0.26 1600 2 Quick Attack Normal ... 0.96 1600 2 Steel Wing Steel 15 Air Cutter Flying 30 NaN
27 Pidgey28 Pidgey 176 35 1.68 0.31 1300 2 Tackle Normal ... 1.13 1300 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
28 Pidgey29 Pidgey 316 47 1.71 0.27 2500 2 Quick Attack Normal ... 0.99 2500 2 Wing Attack Flying 9 Twister Dragon 25 NaN
30 Pidgey31 Pidgey 304 46 1.99 0.31 2500 2 Tackle Normal ... 1.12 2500 2 Steel Wing Steel 15 Twister Dragon 25 NaN
33 Pidgey34 Pidgey 226 40 2.04 0.31 1600 2 Tackle Normal ... 1.12 1600 2 Steel Wing Steel 15 Aerial Ace Flying 30 NaN
34 Pidgey35 Pidgey 220 43 2.13 0.30 1600 2 Tackle Normal ... 1.10 1600 2 Steel Wing Steel 15 Twister Dragon 25 NaN
36 Pidgey37 Pidgey 42 20 2.11 0.33 400 1 Quick Attack Normal ... 1.21 400 1 Wing Attack Flying 9 Aerial Ace Flying 30 NaN
37 Pidgey38 Pidgey 344 56 2.57 0.35 3000 3 Tackle Normal ... 1.29 3000 3 Steel Wing Steel 15 Twister Dragon 25 NaN
38 Pidgey39 Pidgey 108 31 1.19 0.25 800 1 Tackle Normal ... 0.91 800 1 Wing Attack Flying 9 Air Cutter Flying 30 NaN

24 rows × 27 columns

假定題目我們要來預測寶可夢的種類

  • 為二元分類問題: 是 fying 或者是 normal
  • 直觀的假設 HP 和 Weight 為 Strong Feature

In [12]:
normal_f1 = pknormal['hp']
normal_f2 = pknormal['weight']

flying_f1 = pkflying['hp']
flying_f2 = pkflying['weight']

In [13]:
x_c1 = np.array(zip(normal_f1.tolist(),normal_f2.tolist()))
x_c2 = np.array(zip(flying_f1.tolist(),flying_f2.tolist()))
c1_mean = np.mean(x_c1,axis=0)
c2_mean = np.mean(x_c2,axis=0)
xc1_xc1 = x_c1-c1_mean
xc2_xc2 = x_c2-c2_mean

In [14]:
cov_c1 = np.matmul(xc1_xc1.transpose(),xc1_xc1)/c1_mean.shape[0]
cov_c2 = np.matmul(xc2_xc2.transpose(),xc2_xc2)/c2_mean.shape[0]

In [15]:
x = np.append(x_c1,x_c2,axis=0)

In [16]:
y = [0] * x_c1.shape[0] + [1]*x_c2.shape[0]

從下圖看來,兩個 Gaussian Density Function 所產生出來的 Probability 基準不一至


In [17]:
cov = (cov_c1*len(x_c1) + cov_c2*len(x_c2))/(len(x_c1)+len(x_c2))
y1 = multivariate_normal.pdf(x, mean=c1_mean, cov=cov_c1)
y2 = multivariate_normal.pdf(x, mean=c2_mean, cov=cov_c2)
plt.plot(y1)
plt.plot(y2)
plt.show()



In [18]:
error =0
for label,c1p,c2p in zip(y,y1,y2):
    if c1p > c2p and label == 1:
        error += 1
    elif c1p < c2p and label ==0:
        error +=1
print("錯了 {} 個".format(error))


錯了 30 個

採用 share covarience matrix,來降 model 的 varience

  • 用白話來說讓兩個類別的 density function 基準是一至的

In [19]:
cov = (cov_c1*len(x_c1) + cov_c2*len(x_c2))/(len(x_c1)+len(x_c2))

y1 = multivariate_normal.pdf(x, mean=c1_mean, cov=cov)
y2 = multivariate_normal.pdf(x, mean=c2_mean, cov=cov)
plt.plot(y1)
plt.plot(y2)
plt.show()



In [20]:
error =0
for label,c1p,c2p in zip(y,y1,y2):
    if c1p > c2p and label == 1:
        error += 1
    elif c1p < c2p and label ==0:
        error +=1
print("錯了 {} 個".format(error))


錯了 10 個

In [ ]: