train_iris.py
コードの補足説明train_iris.py
をそのまま動かす際には scikitlearn が必要となるので、
pip install scikit-learn
conda install scikit-learn
のどちらかを実行してほしい。
In [1]:
from chainer import cuda
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import pandas as pd
In [2]:
# gpu -> args.gpu と読み替えてほしい
gpu = -1
if gpu >= 0:
# chainer.cuda.get_device(args.gpu).use() # make a specified gpu current
# model.to_gpu() # copy the model to the gpu
xp = cuda.cupy
else:
xp = np
In [3]:
iris = datasets.load_iris()
In [4]:
pd.DataFrame({
'sepal length': np.array([iris.data[x][0] for x in range(150)]), #len(iris.data) -> 150
'sepal width': np.array([iris.data[x][1] for x in range(150)]),
'petal length': np.array([iris.data[x][2] for x in range(150)]),
'petal width': np.array([iris.data[x][3] for x in range(150)]),
'target label': np.array(iris.target)
})
Out[4]:
In [5]:
data_train, data_test, tgt_train, tgt_test = train_test_split(iris.data, iris.target, test_size=0.5)
In [6]:
from collections import Counter
Counter(tgt_train)
Out[6]:
またオプションとして、train_test_sprit
を使わず、
偶数番目のデータを train 、奇数番目のデータを test にするオプションも用意した
実行時の引数に --spritintwo y
or -s y
とすれば実行される
In [7]:
index = np.arange(len(iris.data))
data_train, data_test = iris.data[index[index%2!=0],:], iris.data[index[index%2==0],:]
tgt_train, tgt_test = iris.target[index[index%2!=0]], iris.target[index[index%2==0]]
In [8]:
Counter(tgt_train)
Out[8]:
隠れ層 0 のニューラルネットワークは、 多項ロジスティック回帰と同等 になる。
識別先のクラスを $C = \{c_1, c_2, ... , c_K\}$ 、入力のベクトルを $\mathbf{x} = (x_1, x_2, ... , x_n)^t$ とすると、多項ロジスティック回帰のモデルは、
$$ p(c_k | \mathbf{x}) = \pi (a_k) \\ a_k = (W \mathbf{x} + \mathbf{b})_k $$$\pi$ を softmax 関数とすると、 上式は隠れ0層の NN と等価だとわかる。
コードにおいては、
class MLP_H0(chainer.Chain):
# NO HIDDEN LAYER (= LOGISTIC REGRESSION)
def __init__(self, n_units):
super(MLP_H0, self).__init__(
l1=L.Linear(None, 3), # n_in -> n_out
)
def __call__(self, x):
return self.l1(x)
また L.Classifier
の loss function はデフォルトが softmax_cross_entropy
である
if args.mode == 'MLP_H0':
model = L.Classifier(MLP_H0(args.unit))