第2章 [実践] ニューラルネットワーク

このノートブックでは、第2章で説明した3層パーセプトロンによる手書き数字認識を再現します。

環境構築

このノートブックの内容をお手元のコンピュータ上で再現するために必要な環境は、以下のコマンドを実行することで構築できます。

$ pip install -r requirements.txt

CUDA を使用できる場合は以下のコマンドも実行してください。

$ pip install chainer_cuda_deps

準備

ノートブックの環境を準備します。

matplotlib の設定

グラフが画像を表示するために使用する matplotlib の設定をします。


In [1]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

以下はグラフ用と画像の表示設定です。筆者の環境は Mac なので Osaka フォントを指定しています。


In [2]:
mpl.rcParams['font.family'] = [u'Osaka']
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

ライブラリのインポート

ノートブックで使用するライブラリをインポートします。


In [3]:
import sys
import numpy as np
import pandas as pd
from PIL import Image
import chainer
import chainer.functions as F
import chainer.optimizers
from chainer import cuda
import mnist

GPU の設定

以下は、GPU が使えるかどうかを検査します。使える場合は use_gpuTrue に設定されます。


In [4]:
gpu_device_id = 0
try:
    cuda.init(gpu_device_id)
    use_gpu = True
    print "GPU %d is available." % (gpu_device_id)
except:
    use_gpu = False
    print "GPU %d is not available." % (gpu_device_id)
    print "Error: ", sys.exc_info()[0], sys.exc_info()[1]


GPU 0 is available.

MNIST データ

MNIST データをダウンロードし、それをロードして内容を確認します。

データのダウンロード

データのダウンロードには付属のスクリプトを使用します。


In [5]:
!./download_mnist.sh


Downloading train-images-idx3-ubyte.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9680k  100 9680k    0     0  74606      0  0:02:12  0:02:12 --:--:-- 54765
Downloading train-labels-idx1-ubyte.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 28881  100 28881    0     0  21476      0  0:00:01  0:00:01 --:--:-- 21488
Downloading t10k-images-idx3-ubyte.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1610k  100 1610k    0     0  65852      0  0:00:25  0:00:25 --:--:--  338k
Downloading t10k-labels-idx1-ubyte.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4542  100  4542    0     0   1629      0  0:00:02  0:00:02 --:--:--  1629

データのロード

次に、データをロードします。

データをロードする処理は mnist.py に記載しています。興味がある方はそちらのスクリプトを参照してください。


In [6]:
# MNIST データをロード
mnist_data = mnist.load_data()

# MNIST 画像データは、8ビットのグレースケール画像である。
# これを、ピクセルの輝度が 0〜1 の浮動小数点数で表現される形式へ変換する。
images = mnist_data['images'].astype(np.float32)
images /= 255

# ラベルデータは整数として扱う
labels = mnist_data['labels'].astype(np.int32)

# images と labels は、訓練用データと検証用データを結合したもの。
# ここから、訓練用、検証用のそれぞれのデータを分離すr

train_count = mnist_data['train_count']  # 訓練用データの個数
test_count = mnist_data['test_count']    # 検証用データの個数

# 画像を訓練用 x_train と検証用 x_test に分離
x_train = images[0:train_count]
x_test  = images[train_count:]

# ラベルを訓練用 t_train と検証用 t_test に分離
t_train = labels[0:train_count]
t_test  = labels[train_count:]

データの確認

x_traint_train からデータをいくつか取り出して確認します。


In [7]:
# MNIST 画像を並べて表示する関数
def display_mnist_random(ncols, nrows, pad=1):
    # 表示する画像数
    count = ncols * nrows

    # ランダムに count 個の画像を取り出すためのインデックス
    random_indices = np.random.permutation(train_count)[0:count]

    patch_width  = 28 # MNIST 画像の幅
    patch_height = 28 # MNIST 画像の高さ
    pad = 1 # 画像の間を1ピクセル空ける
    total_width = ncols * patch_width + pad * (ncols - 1)   # 全体の幅
    total_height = nrows * patch_height + pad * (nrows - 1) # 全体の高さ

    # MNIST 画像を配置するためのグレースケール画像を生成し、白 (255) で塗り潰す
    total_image = Image.new('L', (total_width, total_height), 255)
    
    # MNIST 画像に対応するラベルを集めるためのベクトル
    total_labels = np.ndarray((nrows, ncols), dtype=np.int64)

    # MNIST 画像を配置する
    for i in range(0, nrows):
        for j in range(0, ncols):
            index = i * nrows + j
            patch = x_train[index, :].reshape((patch_width, patch_height))
            subimage = Image.fromarray(patch * 255)
            total_image.paste(subimage, (j*(patch_width + pad), i*(patch_height + pad)))
            total_labels[i, j] = t_train[index]

    # 画像を表示する
    fig = plt.figure()
    fig.set_size_inches(10, 10)
    plt.imshow(total_image, cmap=plt.cm.Greys_r, interpolation="none")
    plt.axis('off')

    # 画像と対応するラベルを表示する
    print(total_labels)

display_mnist_random(8, 4)


[[5 0 4 1 9 2 1 3]
 [9 2 1 3 1 4 3 5]
 [1 4 3 5 3 6 1 7]
 [3 6 1 7 2 8 6 9]]

ネットワーク構造の定義


In [8]:
model = chainer.FunctionSet(
    fc1=F.Linear(784, 64),  # 入力層から隠れ層への全結合
    fc2=F.Linear( 64, 10),  # 隠れ層から出力層への全結合
)

GPU が使える場合 (use_gpuTrue の場合) は、ネットワークのパラメータを GPU 用に変換しておきます。


In [9]:
if use_gpu:
    model.to_gpu()

出力の計算


In [10]:
def forward(x):
    u2 = model.fc1(x)   # 入力層から隠れ層への結合
    z2 = F.relu(u2)     # 隠れ層の活性化関数
    u3 = model.fc2(z2)  # 隠れ層から出力層への結合
    return u3

def output(x):
    h = forward(x)
    return F.softmax(h)  # 出力層の活性化関数

def predict(x):
    y = output(x)                      # 確率ベクトル
    d = np.argmax(cuda.to_cpu(y.data)) # 予測数字
    return d

predict 関数内で使っている cuda.to_cpu は、引数で与えられたベクトルが GPU 用に変換されている場合に、それを CPU 用に戻す処理を行います。

バックプロパゲーションによる学習

損失関数を定義します。


In [11]:
def loss(h, t):
    return F.softmax_cross_entropy(h, t)

最適化器を定義します。


In [12]:
optimizer = chainer.optimizers.Adam()
optimizer.setup(model.collect_parameters())

以下の mini_batch_learn 関数は、誌面に掲載したものに GPU 対応コードを入れ、正則化の強さを引数で指定できるようにしてあります。


In [13]:
# ミニバッチ方式による学習
# * x_train - 訓練用の入力データ
# * t_train - 訓練用の教師データ
# * mini_batch_size - ミニバッチのサイズ
# * wd - 正則化の強さ
def mini_batch_learn(x_train, t_train, mini_batch_size=10, wd=0.001):
    sum_loss = 0         # 全体の誤差を保持する変数
    sum_accuracy = 0     # 全体の正解率を保持する変数

    # データの無作為な並び換えを作成
    train_count = len(x_train)
    perm = np.random.permutation(train_count)

    # ミニバッチ単位で学習を進める
    for i in range(0, train_count, mini_batch_size):
        # ミニバッチを取り出す
        x_batch = x_train[perm[i:i + mini_batch_size]]
        t_batch = t_train[perm[i:i + mini_batch_size]]

        # GPU が使える場合はベクトルを GPU 用に変換する
        if use_gpu:
            x_batch = cuda.to_gpu(x_batch)
            t_batch = cuda.to_gpu(t_batch)

        optimizer.zero_grads()         # 勾配をゼロで初期化

        x = chainer.Variable(x_batch)  # 入力データ
        t = chainer.Variable(t_batch)  # 教師データ
        h = forward(x)
        e = loss(h, t)                 # 誤差を計算
        a = F.accuracy(h, t)           # 正解率を計算

        e.backward()                   # 誤差を逆伝搬
        optimizer.weight_decay(wd)     # 正則化
        optimizer.update()             # パラメータを更新

        # 全体の誤差と正解率に加算
        sum_loss     += float(cuda.to_cpu(e.data)) * len(t_batch)
        sum_accuracy += float(cuda.to_cpu(a.data)) * len(t_batch)

    # 誤差と正解率を返す
    train_loss     = sum_loss / train_count
    train_accuracy = sum_accuracy / train_count
    return train_loss, train_accuracy

学習結果の評価

以下の mini_batch_test 関数は誌面に掲載したものに、GPU 対応コードを入れています。


In [14]:
# ミニバッチ方式による評価
# * x_test - 評価用の入力データ
# * t_test - 評価用の教師データ
# * mini_batch_size - ミニバッチのサイズ
def mini_batch_test(x_test, t_test, mini_batch_size=10):
    sum_loss = 0         # 全体の誤差を保持する変数
    sum_accuracy = 0     # 全体の正解率を保持する変数

    test_count = len(x_test)
    for i in range(0, test_count, mini_batch_size):
        # ミニバッチを取り出す
        x_batch = x_test[i:i + mini_batch_size]
        t_batch = t_test[i:i + mini_batch_size]

        # GPU が使える場合はベクトルを GPU 用に変換する
        if use_gpu:
            x_batch = cuda.to_gpu(x_batch)
            t_batch = cuda.to_gpu(t_batch)

        x = chainer.Variable(x_batch)  # 入力データ
        t = chainer.Variable(t_batch)  # 教師データ
        h = forward(x)
        e = loss(h, t)                 # 誤差を計算
        a = F.accuracy(h, t)           # 正解率を計算

        # 全体の誤差と正解率に加算
        sum_loss     += float(cuda.to_cpu(e.data)) * len(t_batch)
        sum_accuracy += float(cuda.to_cpu(a.data)) * len(t_batch)

    # 誤差と正解率を返す
    test_loss     = sum_loss / test_count
    test_accuracy = sum_accuracy / test_count
    return test_loss, test_accuracy

学習、評価の繰り返しと結果の可視化

訓練の実行

ネットワークの訓練を実行します。時間がかかるため、途中経過を表の形で出力するようになっています。 この表は、変数 data_normal に記録される実行結果と同じものです。


In [15]:
# 誤差と正解率の記録用
data_normal = {
    'epoch': [],
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

max_epoch = 50  # 最大の学習回数

# 進捗ログのヘッダ
print("epoch\ttrain_loss\ttrain_accuracy\ttest_loss\ttest_accuracy")

for epoch in range(1, max_epoch + 1):
    data_normal['epoch'].append(epoch)

    # 訓練用データによる学習
    train_e, train_a = mini_batch_learn(x_train, t_train, mini_batch_size=50)
    data_normal['train_loss'].append(train_e)
    data_normal['train_accuracy'].append(train_a)

    # テスト用データによる評価
    test_e, test_a = mini_batch_test(x_test, t_test, mini_batch_size=50)
    data_normal['test_loss'].append(test_e)
    data_normal['test_accuracy'].append(test_a)

    # 進捗ログを出力
    print("{}\t{}\t{}\t{}\t{}".format(epoch, train_e, train_a, test_e, test_a))
    sys.stdout.flush()


epoch	train_loss	train_accuracy	test_loss	test_accuracy
1	0.359655061016	0.903966643897	0.224881760217	0.935099977553
2	0.191943332489	0.9463666448	0.165690085366	0.955299979448
3	0.153881382961	0.956816644271	0.134827750635	0.964099979997
4	0.132305102105	0.963666643649	0.130220859671	0.963199981749
5	0.119586570694	0.967249978185	0.116096135568	0.96549998194
6	0.111350437868	0.969783312033	0.112139235398	0.966199980974
7	0.105423816336	0.970833312174	0.109904114716	0.968799981475
8	0.100152748745	0.973349979271	0.10326864977	0.969899981022
9	0.0982701934145	0.973366645426	0.102664483009	0.971499982476
10	0.0950609500186	0.974683312724	0.0981472349435	0.971899981499
11	0.0937736817674	0.974733312478	0.0980173065176	0.971299982667
12	0.0918106202758	0.975683312515	0.113204626072	0.968699981272
13	0.0909998125102	0.975233312647	0.0953121963335	0.97329998225
14	0.0892766897194	0.976183312635	0.0978553123656	0.972799982131
15	0.0882546662213	0.976116645386	0.0963684241881	0.974399983883
16	0.087742462776	0.976516646047	0.0947780142719	0.974899982214
17	0.0874628156532	0.976433312049	0.0955648986832	0.97299998492
18	0.0863385855546	0.977183312128	0.0956249162758	0.97459998399
19	0.0861045261395	0.977116645773	0.0859532257007	0.976399983764
20	0.0852446304254	0.977266646177	0.0898616158124	0.974499983191
21	0.085004784686	0.977499979834	0.0922170675203	0.974299983084
22	0.0839522645685	0.977766647438	0.0880183293414	0.97449998498
23	0.0852502167753	0.977816646447	0.0930008199421	0.974799984694
24	0.0837275226826	0.977766646047	0.0905026924657	0.975099983811
25	0.0830854327907	0.977949979852	0.0942565514764	0.973399982452
26	0.0836270677935	0.978399979273	0.0952348391782	0.971699983478
27	0.0832614008671	0.977733312994	0.0926066780055	0.974499984086
28	0.0831181173368	0.977683313092	0.0883273574733	0.974999984503
29	0.0822636285354	0.97824998046	0.0904654508166	0.974099982679
30	0.083272093973	0.977916646103	0.0869425298949	0.97699998498
31	0.0820789040097	0.977983313153	0.0872625438508	0.974699983001
32	0.081716004624	0.978583313028	0.0933940995112	0.974499984086
33	0.0823466015064	0.978216646264	0.087607504882	0.975399983823
34	0.0811822122405	0.978816646785	0.0873254689178	0.975399983227
35	0.0816339547761	0.978533313076	0.0917922851769	0.975099984407
36	0.082261967437	0.977983313402	0.0873258681264	0.974299982488
37	0.0812257956426	0.979066647192	0.0865767294611	0.974299981892
38	0.0810708712321	0.978549978584	0.0872671215358	0.975399985313
39	0.0809498288273	0.978666645636	0.0909749849513	0.974199982882
40	0.0806847720616	0.979183313549	0.0875043327257	0.975999985039
41	0.0804163613791	0.979149980893	0.0898726037366	0.975099983811
42	0.0809078967905	0.978549979677	0.0895405822824	0.973999985456
43	0.0809142092724	0.978399979472	0.0886359694542	0.974999983311
44	0.0808429462874	0.978583313177	0.0954292944336	0.972399981618
45	0.0804177693177	0.978649979929	0.0870750140271	0.975699984133
46	0.0804200512241	0.978816646685	0.0944644015381	0.973099983335
47	0.079932305885	0.978399980019	0.0925784226554	0.973899983764
48	0.0812201662625	0.978866646737	0.10069525082	0.96999998033
49	0.0801326512258	0.978783313682	0.0900188406464	0.974199982584
50	0.0801330439281	0.978833313535	0.0871437774424	0.976499985754

結果の可視化

次に結果の可視化です。学習結果 data_normal を pandas のデータフレームに変換し、データフレームの機能を使ってグラフを描画します。


In [16]:
# 学習結果のデータフレームを作成
df_normal = pd.DataFrame(data_normal)

誤差の折れ線グラフを描きます。


In [17]:
# 誤差の折れ線グラフ
fig = plt.figure()
ax = df_normal.plot(x='epoch', y='train_loss', style="k-")
df_normal.plot(x='epoch', y='test_loss', style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'学習ステップ')
ax.set_ylabel(u'誤差')
ax.legend(labels=[u'訓練', u'テスト'])
plt.tight_layout()


<matplotlib.figure.Figure at 0x11a9acf90>

正解率の折れ線グラフを描きます。


In [18]:
# 正解率の折れ線グラフ
fig = plt.figure()
ax = df_normal.plot(x='epoch', y='train_accuracy', style="k-")
df_normal.plot(x='epoch', y='test_accuracy', style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'学習ステップ')
ax.set_ylabel(u'正解率')
ax.legend(labels=[u'訓練', u'テスト'], loc='lower right')
plt.tight_layout()


<matplotlib.figure.Figure at 0x130e21dd0>

過学習する場合

正則化を弱めて過学習を起こしてみます。


In [19]:
# 後で可視化のため、学習済みのパラメータを保存
fc1_normal = model.fc1
fc2_normal = model.fc2

# 過学習を再現するためのモデルを初期化
model = chainer.FunctionSet(
    fc1=F.Linear(784, 64),  # 入力層から隠れ層への全結合
    fc2=F.Linear( 64, 10),  # 隠れ層から出力層への全結合
)

if use_gpu:
    model.to_gpu()

optimizer = chainer.optimizers.Adam()
optimizer.setup(model.collect_parameters())

# 誤差と正解率の記録用
data_overfit = {
    'epoch': [],
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}
max_epoch = 50  # 最大の学習回数

# 進捗ログのヘッダ
print("epoch\ttrain_loss\ttrain_accuracy\ttest_loss\ttest_accuracy")

for epoch in range(1, max_epoch + 1):
    data_overfit['epoch'].append(epoch)

    # 訓練用データによる学習
    train_e, train_a = mini_batch_learn(x_train, t_train, mini_batch_size=50, wd=0.0001)
    data_overfit['train_loss'].append(train_e)
    data_overfit['train_accuracy'].append(train_a)

    # テスト用データによる評価
    test_e, test_a = mini_batch_test(x_test, t_test, mini_batch_size=50)
    data_overfit['test_loss'].append(test_e)
    data_overfit['test_accuracy'].append(test_a)

    # 進捗ログを出力
    print("{}\t{}\t{}\t{}\t{}".format(epoch, train_e, train_a, test_e, test_a))
    sys.stdout.flush()


epoch	train_loss	train_accuracy	test_loss	test_accuracy
1	0.343428140972	0.905466643547	0.194564190011	0.941299978793
2	0.169079585131	0.951416644255	0.145713470117	0.956599978805
3	0.125221656743	0.964449977527	0.114407486626	0.966299981773
4	0.0995551423371	0.971166645288	0.100288384082	0.970699985027
5	0.0824073036896	0.976616645604	0.0888286539604	0.973099982142
6	0.0708130922075	0.979033313493	0.0768014486926	0.975699984133
7	0.0618805315318	0.981983314306	0.0793604097835	0.97429998368
8	0.0553947909058	0.983683316062	0.0740876693874	0.975899981856
9	0.0500163263203	0.98533331578	0.0799005864194	0.975399983227
10	0.0458989543574	0.98651665017	0.0779694265859	0.975199982822
11	0.0420274807532	0.987683317463	0.0747612082324	0.976799982488
12	0.0385848575258	0.988833318849	0.07003085741	0.97649998337
13	0.0357989175025	0.989883319587	0.0710867349312	0.977099981904
14	0.0343326061325	0.990083318303	0.0694025605466	0.978099983633
15	0.0318878220455	0.990749986966	0.0733289850564	0.976699982285
16	0.0305617100241	0.991499987543	0.0753733393908	0.974999983311
17	0.028011661561	0.991899987459	0.0714968280259	0.978299983442
18	0.0276963787262	0.992066654513	0.0772402072982	0.974599982798
19	0.0260463347967	0.992899988492	0.0727367349036	0.977199980617
20	0.0250573464783	0.993149988751	0.0755888141892	0.975299981833
21	0.023926104053	0.993533323159	0.0698366079118	0.978199982345
22	0.023945006195	0.993599989514	0.0747423777091	0.976499982178
23	0.0226393654579	0.993933323572	0.0694091012409	0.978299981654
24	0.022028136672	0.994083324124	0.0759969412056	0.975799982548
25	0.0208378163635	0.994449991335	0.0739873579875	0.976499982476
26	0.0212746098176	0.994183324625	0.0735700956465	0.977499982119
27	0.0198058848317	0.994616657595	0.0836060139413	0.974799981415
28	0.0205076431163	0.994416658133	0.0733051468054	0.977099983394
29	0.0201665252843	0.994633324544	0.0810852317711	0.974199982285
30	0.019953020424	0.994766658346	0.0716299171904	0.977799982131
31	0.0191351806238	0.995049991806	0.0750536189553	0.97739998281
32	0.0186887535426	0.995199991763	0.0739121172769	0.976699982584
33	0.0188708610442	0.995199992359	0.0737601221621	0.976599981487
34	0.0181476659381	0.99538332507	0.0753856973259	0.977299983799
35	0.0184081341865	0.995366658767	0.0703331617602	0.978499981165
36	0.0175307102532	0.995816659878	0.0728765718833	0.977399982214
37	0.0173161948615	0.995699992776	0.075138355291	0.976599983573
38	0.0169783935086	0.995699992528	0.075146034296	0.97649998188
39	0.017431540849	0.99578332603	0.0722596269169	0.977499983013
40	0.0169591181786	0.995849992832	0.0742281734749	0.977299982011
41	0.0168114545172	0.996066660136	0.0761315774655	0.97699998349
42	0.0169340229075	0.995649992873	0.0881071453958	0.973199982047
43	0.0160836212821	0.995999993334	0.0768986437484	0.976499982774
44	0.0161285047917	0.996316660295	0.0726908422959	0.977699983418
45	0.0167129195567	0.995899992685	0.0744199643217	0.97809998244
46	0.0175403467883	0.995383324971	0.0793827724612	0.975399982631
47	0.0150739175136	0.996616660357	0.0783011073606	0.976499981284
48	0.0165075541224	0.99609999314	0.0846949970691	0.974799982309
49	0.0168119139272	0.995633325974	0.0712057356638	0.97799998194
50	0.0164745597076	0.995866659681	0.0784969747218	0.976299983859

In [20]:
# 学習結果のデータフレームを作成
df_overfit = pd.DataFrame(data_overfit)

In [21]:
# 誤差の折れ線グラフ
fig = plt.figure()
ax = df_overfit.plot(x='epoch', y='train_loss', style="k-")
df_overfit.plot(x='epoch', y='test_loss', style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'学習ステップ')
ax.set_ylabel(u'誤差')
ax.legend(labels=[u'訓練', u'テスト'])
plt.tight_layout()


<matplotlib.figure.Figure at 0x130fac7d0>

In [22]:
# 正解率の折れ線グラフ
fig = plt.figure()
ax = df_overfit.plot(x='epoch', y='train_accuracy', style="k-")
df_overfit.plot(x='epoch', y='test_accuracy', style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'学習ステップ')
ax.set_ylabel(u'正解率')
ax.legend(labels=[u'訓練', u'テスト'], loc='lower right')
plt.tight_layout()


<matplotlib.figure.Figure at 0x131344f90>

隠れ層と出力層の可視化

隠れ層と出力層は以下の関数で可視化します。


In [23]:
# take an array of shape (n, height, width) or (n, height, width, channels)
# and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)
def display_layer(weights, padsize=1, padval=0, shape=None):
    n = weights.shape[0]
    patch_width = weights.shape[1]
    patch_height = weights.shape[2]

    if shape != None:
        nrows = shape[0]
        ncols = shape[1]
    else:
        ncols = int(np.ceil(np.sqrt(n)))
        nrows = int(np.ceil(float(n) / ncols))

    image_width  = ncols * patch_width  + padsize*(ncols - 1)
    image_height = nrows * patch_height + padsize*(nrows - 1)

    image = Image.new('L', (image_width, image_height), padval*255)
    
    for i in range(0, nrows):
        for j in range(0, ncols):
            k = i * ncols + j
            if k < n:
                patch = weights[k, :].reshape((patch_width, patch_height))
                patch -= np.min(patch)
                patch /= np.max(patch)
                patch *= 255
                subimage = Image.fromarray(patch)
                image.paste(subimage, (j*(patch_width + padsize), i*(patch_height + padsize)))

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.imshow(image)
    plt.axis('off')
    plt.tight_layout()

まずは、隠れ層です。


In [24]:
display_layer(cuda.to_cpu(model.fc1.W).reshape((64, 28, 28)), padval=1)


そして、隠れ層です。


In [25]:
display_layer(cuda.to_cpu(model.fc2.W).reshape((10, 8, 8)), padval=1, shape=(2, 5))


隠れ層のニューロンが増減するとどうなるか

隠れ層のニューロン数の変化に対する、誤差と正解率と計算時間の変化を調べます。


In [26]:
import time

# 誤差と正解率の記録用
data_variable_size = {
    'hidden_size': [],
    'duration': [],
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

max_epoch = 50  # 最大の学習回数

for n in range(9, 2, -1):
    hidden_size = n*n
    data_variable_size['hidden_size'].append(hidden_size)

    model = chainer.FunctionSet(
        fc1=F.Linear(784, hidden_size),  # 入力層から隠れ層への全結合
        fc2=F.Linear(hidden_size,  10),  # 隠れ層から出力層への全結合
    )

    if use_gpu:
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model.collect_parameters())

    # 進捗ログのヘッダ
    print("hidden_size\tepoch\ttrain_loss\ttrain_accuracy\ttest_loss\ttest_accuracy")
    sys.stdout.flush()

    # 実行時間を簡易的に計測するために開始時刻を記録
    started_at = time.time()

    for epoch in range(1, max_epoch + 1):
        # 訓練用データによる学習
        train_e, train_a = mini_batch_learn(x_train, t_train, mini_batch_size=100)

        # テスト用データによる評価
        test_e, test_a = mini_batch_test(x_test, t_test, mini_batch_size=100)

        # 進捗ログは毎 epoch について出力
        print("{:<11}\t{}\t{}\t{}\t{}\t{}".format(hidden_size, epoch, train_e, train_a, test_e, test_a))
        sys.stdout.flush()

    # 実行時間の簡易的な算出
    duration = time.time() - started_at
    print("duration = {}\n\n".format(duration))
    data_variable_size['duration'].append(duration)

    # 最終 epoch における誤差と正解率を記録
    data_variable_size['train_loss'].append(train_e)
    data_variable_size['train_accuracy'].append(train_a)
    data_variable_size['test_loss'].append(test_e)
    data_variable_size['test_accuracy'].append(test_a)


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
81         	1	0.396082543992	0.895033315606	0.227381410208	0.933999978304
81         	2	0.208653101288	0.941199982862	0.17692938583	0.949499976039
81         	3	0.168193971937	0.952966646651	0.143878735593	0.960399978161
81         	4	0.142676890461	0.960216643115	0.133691288373	0.962199978828
81         	5	0.128122086059	0.964716640611	0.123796265298	0.963599977493
81         	6	0.116906136256	0.967916638752	0.11299698289	0.968499978781
81         	7	0.109803729213	0.970299970806	0.107009004795	0.969999978542
81         	8	0.103776946592	0.972233303686	0.102107790792	0.971499976516
81         	9	0.0991652312595	0.973199968239	0.0985883197468	0.972299976349
81         	10	0.0965949409921	0.97446663479	0.105309006853	0.969899975061
81         	11	0.0936900063232	0.974599969685	0.0944286079402	0.974699975252
81         	12	0.0909616821166	0.975933303138	0.0948218138539	0.972399976254
81         	13	0.089171919357	0.975933301846	0.0949801960983	0.971799976826
81         	14	0.0876385754595	0.976666634977	0.0902978002396	0.97379997611
81         	15	0.0858145673325	0.977699967921	0.0945237453049	0.972799975872
81         	16	0.0858272746516	0.97686663568	0.0924856777349	0.97459997952
81         	17	0.0840733144122	0.977683300376	0.0943561210134	0.973399977088
81         	18	0.0833632776886	0.978049967686	0.092318458783	0.973999978304
81         	19	0.0824897224394	0.978566635251	0.0897064059414	0.972899978757
81         	20	0.0821377141432	0.978383301894	0.0888363095187	0.975599976182
81         	21	0.0812880628618	0.978949967225	0.0885809043376	0.974399976134
81         	22	0.0808896454796	0.979266634186	0.0883486865391	0.975499976277
81         	23	0.0800207759906	0.979149967432	0.0908928057412	0.974899977446
81         	24	0.0800458672208	0.979283301135	0.0922391639068	0.973999974728
81         	25	0.0806316886966	0.978883300424	0.0900379785057	0.974099974632
81         	26	0.079057868716	0.979816634357	0.0873592846445	0.975399976969
81         	27	0.0790817948446	0.980116633773	0.089356432599	0.975299978256
81         	28	0.0782046813083	0.980299967329	0.0965927113523	0.972499978542
81         	29	0.0780043777389	0.980099966923	0.085471025547	0.976599978209
81         	30	0.0776588376705	0.979999966919	0.0866722953203	0.975599976182
81         	31	0.0778433806698	0.979783300857	0.0859822756087	0.975699977875
81         	32	0.0774904880269	0.980133300324	0.0871035993937	0.974499976635
81         	33	0.0774851339217	0.980183300575	0.0856929946435	0.975499978662
81         	34	0.0767618056169	0.979866634806	0.0902474218281	0.974399978518
81         	35	0.0770211240494	0.98003330042	0.0863906719093	0.975299977064
81         	36	0.0764718763127	0.980449967484	0.0865221034142	0.97489997685
81         	37	0.0763979567805	0.980533301234	0.0859675071854	0.97589997828
81         	38	0.0767085492145	0.980083300571	0.0878562329349	0.974799978137
81         	39	0.0758761550548	0.980383300483	0.0837522404082	0.976599979401
81         	40	0.0764318316399	0.980216634274	0.0868336884305	0.976199976206
81         	41	0.0757735126403	0.981166633268	0.0879706342123	0.976099976301
81         	42	0.0758696911794	0.980199967325	0.0816178569873	0.976799976826
81         	43	0.0753925541416	0.980833300451	0.0874629899801	0.975799977779
81         	44	0.0750948413555	0.980866634548	0.0842592773342	0.974699978828
81         	45	0.0755309456432	0.980266634027	0.0841287342052	0.975199978352
81         	46	0.0755539570873	0.980649966399	0.0821106989356	0.976599978209
81         	47	0.0751763957599	0.980499967138	0.0842067703337	0.975199977756
81         	48	0.0741331003979	0.981533301671	0.0867753560678	0.974899978638
81         	49	0.0750406411197	0.980766634047	0.0908812942845	0.974999975562
81         	50	0.07498737347	0.981116634309	0.0837191163201	0.977099975944
duration = 142.791613102


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
64         	1	0.416649933408	0.887983315003	0.230690641766	0.933999979496
64         	2	0.214431159484	0.940416650077	0.180412601717	0.946399978399
64         	3	0.1758590998	0.950883315007	0.164463235587	0.953699979186
64         	4	0.152826637557	0.957166644831	0.147749124677	0.956699978113
64         	5	0.138555980163	0.961249976655	0.138425158369	0.961199977994
64         	6	0.126560797188	0.964766640663	0.123697355737	0.964299977422
64         	7	0.118564143038	0.967049972117	0.118655170214	0.965899977088
64         	8	0.111604884133	0.969183304707	0.109480930578	0.968499975801
64         	9	0.106271441998	0.971683303316	0.108459147119	0.969199978113
64         	10	0.102725765215	0.971699970166	0.104042861506	0.971699977517
64         	11	0.0991391189645	0.973799969554	0.109829809954	0.968499976397
64         	12	0.0960952874211	0.974649968147	0.106499224331	0.970299980044
64         	13	0.0936812728674	0.97531663537	0.0971578527731	0.973599978089
64         	14	0.0922066390825	0.975633301636	0.0957073029527	0.972699979544
64         	15	0.0907686530116	0.976149969697	0.0937809191039	0.974899977446
64         	16	0.0892187881252	0.976566635072	0.0965835510916	0.972399976254
64         	17	0.0883682159179	0.976883301636	0.0939392659813	0.972599974871
64         	18	0.087505116413	0.977016632855	0.090941770873	0.973899976611
64         	19	0.0855352434826	0.977899967631	0.0946461701952	0.972899977565
64         	20	0.0851007128693	0.977333300412	0.0950902048778	0.97249997735
64         	21	0.0841342312687	0.978083301584	0.0943778811675	0.973299974203
64         	22	0.083491774664	0.978166634639	0.092364377483	0.973799979091
64         	23	0.0827164498158	0.978633300165	0.0932662497391	0.973699979186
64         	24	0.0821437694629	0.978833301465	0.0946451345645	0.972399978638
64         	25	0.0820050612744	0.979016634425	0.0879877446964	0.973999977112
64         	26	0.0816033064543	0.979049966534	0.0876845676429	0.975099975467
64         	27	0.0810480189727	0.97923330158	0.090159659083	0.974499977827
64         	28	0.0806551008485	0.97879996717	0.0852121498156	0.975499976277
64         	29	0.0803834103808	0.979116633932	0.0890870246396	0.973899977803
64         	30	0.0797052106913	0.979266633391	0.088395958005	0.974899980426
64         	31	0.0792308840792	0.979566632708	0.0896199468954	0.974799979925
64         	32	0.07951249348	0.979699967901	0.091862710868	0.974499978423
64         	33	0.0791511546075	0.979349967837	0.089784125071	0.974499977827
64         	34	0.0786530213648	0.980049968561	0.0904386121407	0.972899978161
64         	35	0.0783273200706	0.980066633423	0.0888219936029	0.974899976254
64         	36	0.0779584805729	0.980049966176	0.0876133997249	0.975299978852
64         	37	0.0780296670211	0.97996663332	0.0858865363058	0.975999976993
64         	38	0.0780984818656	0.980399968127	0.0873444316164	0.97559997797
64         	39	0.0771099277741	0.980349968374	0.0890652739815	0.974299976826
64         	40	0.077163927375	0.980949965815	0.0912481539696	0.974299980998
64         	41	0.0768880850729	0.980399967432	0.0843068249873	0.975999977589
64         	42	0.0766338104227	0.980299967527	0.0948666862585	0.971799977422
64         	43	0.0768269833146	0.980616633594	0.0886571476771	0.975799977183
64         	44	0.0771524642284	0.980149968068	0.0890663208882	0.974999979138
64         	45	0.0772539986484	0.979666633109	0.08734173876	0.975099979043
64         	46	0.0764183218001	0.980349966486	0.0906357711181	0.973799973726
64         	47	0.075828447137	0.980633299649	0.0878208359377	0.973999977708
64         	48	0.0765714425966	0.980649967293	0.0881093963562	0.975099979043
64         	49	0.075667183899	0.980766633948	0.086423337817	0.975099978447
64         	50	0.0756861553869	0.980449968278	0.0832524683652	0.975399977565
duration = 135.528563023


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
49         	1	0.42496985803	0.888349980464	0.242321558828	0.932399976254
49         	2	0.22601304402	0.935899982949	0.195378957	0.944499976039
49         	3	0.185571525575	0.948016648988	0.166418279842	0.951999980211
49         	4	0.16151101031	0.955016646286	0.150957076391	0.958099978566
49         	5	0.145067841696	0.959633309742	0.140809050198	0.959699975252
49         	6	0.133112934635	0.963449974656	0.129462629196	0.964699976444
49         	7	0.123139275834	0.965949974259	0.125365797537	0.964799975157
49         	8	0.11689688546	0.967866639197	0.118028486231	0.966199975014
49         	9	0.111315790604	0.970266637901	0.113782374	0.96769997716
49         	10	0.106565351719	0.971416637798	0.111730985863	0.968499979377
49         	11	0.104196631551	0.971766637166	0.110541290201	0.96839997828
49         	12	0.10079837638	0.972633303006	0.108811041638	0.969699978828
49         	13	0.0987392457295	0.973533302446	0.104251633836	0.970599975586
49         	14	0.0966452582118	0.974283301731	0.105567925144	0.970799976587
49         	15	0.0946273806598	0.974616634448	0.114098234684	0.966299977303
49         	16	0.0938237789242	0.974966635009	0.103193046476	0.970099975467
49         	17	0.0915877633666	0.975633302232	0.102760310778	0.970299980044
49         	18	0.0909569542545	0.975683301489	0.105143371872	0.969199973941
49         	19	0.0899352277722	0.976349969705	0.097924290942	0.972299977541
49         	20	0.0891852918112	0.976066634258	0.107872928116	0.971399976015
49         	21	0.088495758092	0.975883301298	0.0982207782264	0.972299977541
49         	22	0.0874076782881	0.976349967817	0.0949498594226	0.972899977565
49         	23	0.0870121641954	0.977299967209	0.0960833998886	0.97159997642
49         	24	0.0862918074988	0.977016634345	0.0977815716504	0.9739999789
49         	25	0.0859810917669	0.977249968847	0.0971644257684	0.972099974155
49         	26	0.0848159928651	0.977233301103	0.0945460003568	0.97309997499
49         	27	0.0849584223672	0.977566634218	0.104150670515	0.970599977374
49         	28	0.0855145881604	0.976983302633	0.0947418735456	0.974499977827
49         	29	0.0841674096261	0.977933301429	0.0960233874212	0.972999977469
49         	30	0.0843783575855	0.977783302665	0.0943815618521	0.973099976778
49         	31	0.0831626999006	0.978116633395	0.0972264610347	0.973699975014
49         	32	0.0830871366616	0.977499967913	0.0992042798526	0.972599979639
49         	33	0.0833607416507	0.977649967074	0.0940540234768	0.974199978113
49         	34	0.0825972068869	0.978116634587	0.0937696036941	0.973899976611
49         	35	0.0822923145847	0.978166634142	0.100606777801	0.971899975538
49         	36	0.0823770637934	0.978266634246	0.0977699727356	0.971399977207
49         	37	0.082097275015	0.978266633948	0.0920142753702	0.974199979305
49         	38	0.0821757924339	0.978133299748	0.0954888190748	0.974299978018
49         	39	0.0818349010187	0.978216633598	0.0956128694536	0.97269997716
49         	40	0.0816750141637	0.978416633209	0.100299485126	0.972999978065
49         	41	0.0812510922644	0.978066634138	0.0953299640305	0.973599976897
49         	42	0.0814067039763	0.978433300455	0.0919573889126	0.974799978733
49         	43	0.0817666132345	0.979149967233	0.10017958059	0.971999977231
49         	44	0.0810944851022	0.978916634023	0.0949980109895	0.973799977899
49         	45	0.0811356360869	0.978499969145	0.0963481318788	0.973699976206
49         	46	0.0809581809429	0.979016634822	0.0927390293288	0.974199978113
49         	47	0.081217255524	0.978766633272	0.0936506439734	0.973899976015
49         	48	0.0799795160567	0.978849967221	0.0906430186809	0.974099975824
49         	49	0.0807688882388	0.978933300376	0.0919841882575	0.973099975586
49         	50	0.0790336362117	0.979066633681	0.0975224739802	0.974199977517
duration = 133.285835028


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
36         	1	0.475990780691	0.874216647049	0.271954444367	0.925099977851
36         	2	0.25314788904	0.930083315174	0.221098707383	0.937499979138
36         	3	0.215997224543	0.941583316525	0.194586481168	0.94629997611
36         	4	0.193616341874	0.947099981109	0.176769094658	0.950299976468
36         	5	0.177212543227	0.950183314284	0.16454985701	0.952699978352
36         	6	0.163017296636	0.954199979603	0.162712873737	0.953499978781
36         	7	0.154193916246	0.957116644482	0.149367428483	0.958599979877
36         	8	0.14579907241	0.959166644017	0.143509159666	0.958899977803
36         	9	0.139478647523	0.960716642737	0.134362456482	0.962599979043
36         	10	0.134145660674	0.963033308387	0.132495961334	0.963099978566
36         	11	0.130169932581	0.963916640878	0.131258318406	0.962899976969
36         	12	0.125230064634	0.965916639368	0.126545584546	0.965499976277
36         	13	0.122042613123	0.9667499729	0.122325576385	0.965599976778
36         	14	0.119493897775	0.967283306519	0.120381987276	0.967599976659
36         	15	0.117044405794	0.967866638998	0.116560331448	0.966299976707
36         	16	0.114027758806	0.969616637429	0.120379999704	0.965299975872
36         	17	0.112441866609	0.968933305144	0.11671699458	0.965599977374
36         	18	0.110619931569	0.970099970798	0.116183796106	0.968499976397
36         	19	0.108900721275	0.970199971696	0.113540767794	0.967199977636
36         	20	0.107721300811	0.970699970524	0.114730726646	0.967799978852
36         	21	0.106396092869	0.971366636852	0.114174881084	0.967599975467
36         	22	0.105227247085	0.971399970154	0.114213561502	0.967399978042
36         	23	0.103057515882	0.972783303261	0.110368144945	0.968299978375
36         	24	0.102341988354	0.972633303006	0.109737752038	0.968799975514
36         	25	0.101430519801	0.973266635636	0.112675888422	0.966999977827
36         	26	0.100538451687	0.972683304548	0.105079575023	0.969899974465
36         	27	0.0993889217265	0.973283302287	0.112987229805	0.967699976563
36         	28	0.0985811650536	0.974283302327	0.113900280667	0.967799977064
36         	29	0.0982887638702	0.974583302935	0.105360423562	0.970199976563
36         	30	0.0965411558592	0.974183301628	0.103383705399	0.970399976969
36         	31	0.0965134245375	0.97464996924	0.108629129024	0.967899976969
36         	32	0.0958548901882	0.974883302947	0.107681749465	0.967299975157
36         	33	0.0951251893894	0.974933301111	0.103020631187	0.970399977565
36         	34	0.0941377596309	0.974666636189	0.103891561823	0.970399975777
36         	35	0.0940112734338	0.975166635315	0.103730799127	0.971199977994
36         	36	0.0927969362276	0.975616635581	0.105380800811	0.971799973845
36         	37	0.0928144451976	0.975549967686	0.102494736712	0.971299976707
36         	38	0.0932559388721	0.975899968843	0.0995676814159	0.971199977398
36         	39	0.0920473308613	0.97586663574	0.101341183661	0.971699978709
36         	40	0.0919614619296	0.975266635815	0.100245403277	0.970999978185
36         	41	0.0911848847164	0.97599996835	0.107457676777	0.96879997611
36         	42	0.0914723600975	0.975349968672	0.101603315324	0.97129997611
36         	43	0.091811785223	0.975949967901	0.097552254058	0.97309997499
36         	44	0.0901075276919	0.976399968068	0.100103718478	0.971399975419
36         	45	0.0905388320424	0.975949968298	0.0984084632806	0.971699978113
36         	46	0.090290598652	0.976183301906	0.102870702455	0.970199976563
36         	47	0.0901542076034	0.976299967666	0.0993985608453	0.971799978614
36         	48	0.0894114552097	0.976799967786	0.0986704594339	0.971899976134
36         	49	0.0896843817271	0.976649968127	0.0988923824695	0.972999977469
36         	50	0.08937333229	0.976533302069	0.100244459512	0.973299975395
duration = 129.395678043


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
25         	1	0.521608981416	0.86154998082	0.288755603619	0.920199978352
25         	2	0.273924778365	0.924049983621	0.242476443462	0.931299978495
25         	3	0.234694349517	0.935183316867	0.219494108194	0.938199979067
25         	4	0.210912742428	0.941933316588	0.200830161804	0.943299977183
25         	5	0.194231184945	0.946099981765	0.188685285938	0.948299980164
25         	6	0.180756878064	0.950166646441	0.179022925352	0.949699978828
25         	7	0.171039653147	0.952316647371	0.172092537377	0.952099980116
25         	8	0.162807763573	0.954666646222	0.162008413305	0.953799977899
25         	9	0.15529394875	0.95678331097	0.155167303248	0.956799979806
25         	10	0.149086981031	0.959249975681	0.150640704087	0.957199975848
25         	11	0.143121381675	0.960349976023	0.146084262608	0.958299976587
25         	12	0.138110712059	0.962133309146	0.141355824764	0.959499974847
25         	13	0.133669628277	0.963333309094	0.138840296187	0.962399977446
25         	14	0.130756692116	0.963766641219	0.136628973263	0.962499976754
25         	15	0.127892538998	0.964449974994	0.131345929138	0.963099974394
25         	16	0.12531523337	0.964766641855	0.133481743152	0.962499976754
25         	17	0.121696379731	0.966266640425	0.126084727494	0.96469997704
25         	18	0.120192172304	0.966699973742	0.128695115419	0.965099975467
25         	19	0.11786666913	0.967799972395	0.122195038735	0.966799978018
25         	20	0.116646764018	0.967599972785	0.124614894255	0.965199975967
25         	21	0.114709556041	0.968649971584	0.12216195731	0.964599977732
25         	22	0.113516146895	0.968866637647	0.121930467593	0.964899976254
25         	23	0.112358221151	0.968916638196	0.121550426083	0.964699977636
25         	24	0.111565529375	0.969416638513	0.12364323678	0.964599978328
25         	25	0.110662601137	0.969333304763	0.118663666295	0.966499976516
25         	26	0.109711458289	0.96971663783	0.119770371388	0.965999978781
25         	27	0.108628733292	0.970383305053	0.119036038083	0.96519997716
25         	28	0.108629336348	0.969949971338	0.117950360232	0.966299978495
25         	29	0.107351606019	0.970783304671	0.119075481419	0.96619997859
25         	30	0.107302479154	0.970549971461	0.117197259031	0.967099978924
25         	31	0.107123492894	0.970249970953	0.113452887989	0.968199978471
25         	32	0.106382296405	0.970833304425	0.119274460992	0.965999977589
25         	33	0.10602434556	0.971049971183	0.11763701346	0.966999977827
25         	34	0.104664448677	0.971333304147	0.116707304961	0.96559997797
25         	35	0.10468133547	0.971349970798	0.112774270696	0.968299976587
25         	36	0.104208168546	0.971766637365	0.116307280178	0.967899976969
25         	37	0.103987163107	0.971199970643	0.113555772868	0.968699976802
25         	38	0.103649802019	0.97158330361	0.115632669583	0.967599977851
25         	39	0.103373377177	0.971883304814	0.116671458352	0.967599977255
25         	40	0.102564478423	0.972633302907	0.111340349711	0.967499977946
25         	41	0.102500582083	0.972249969939	0.110896882901	0.968599978089
25         	42	0.10274475726	0.972233303487	0.109649408703	0.968499977589
25         	43	0.101588968467	0.972299969792	0.115415577048	0.96749997735
25         	44	0.101412783762	0.972899969518	0.109313287467	0.967999976873
25         	45	0.101330799318	0.972783302168	0.111032669623	0.967999976277
25         	46	0.101127768196	0.972849969268	0.110991472681	0.968099976778
25         	47	0.100999826826	0.972716635962	0.112516805672	0.96709997654
25         	48	0.100892505016	0.972999970118	0.109510628288	0.97029997766
25         	49	0.100326466868	0.97229997009	0.108081802137	0.969599977732
25         	50	0.0996287064285	0.973616636395	0.110670211208	0.969299974442
duration = 119.104342937


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
16         	1	0.592698701099	0.837816647689	0.302083058134	0.917599982619
16         	2	0.284472962767	0.920799982647	0.252375121722	0.929499978423
16         	3	0.249509213952	0.930683316787	0.230511852102	0.937599978447
16         	4	0.232028322779	0.933799982965	0.219929700019	0.938599980474
16         	5	0.220180080322	0.937566649516	0.218781421632	0.937599979639
16         	6	0.212799863331	0.939999982218	0.207823781306	0.939399977922
16         	7	0.205216547772	0.941966648002	0.203349125478	0.942999978662
16         	8	0.200094625254	0.944133315682	0.198191204751	0.941999979615
16         	9	0.195273961251	0.944899982413	0.193193631349	0.945099981427
16         	10	0.190521811477	0.946833314399	0.18848364464	0.946999979019
16         	11	0.187490786177	0.947516647677	0.187152695069	0.947199978232
16         	12	0.184505761024	0.947933314244	0.187395842094	0.947999979258
16         	13	0.180401333359	0.948916647931	0.186477743844	0.949199974537
16         	14	0.178098841943	0.950083313783	0.180751724266	0.948199979067
16         	15	0.175009684867	0.95089998126	0.178736979933	0.948799979687
16         	16	0.172695825553	0.951916647553	0.173499521445	0.950399978757
16         	17	0.170265760732	0.952333313028	0.176059009028	0.94999998033
16         	18	0.168426889076	0.952666647236	0.17377933451	0.948999978304
16         	19	0.165933817786	0.953616645734	0.173699509893	0.951199979782
16         	20	0.164604686641	0.954516647458	0.171596231773	0.951299980283
16         	21	0.162268086548	0.953983313143	0.167989675952	0.952499979734
16         	22	0.160512511277	0.954749979277	0.166553553133	0.953599978089
16         	23	0.158364239614	0.956416645646	0.166439946541	0.950299980044
16         	24	0.157335969284	0.955099978745	0.160823824964	0.953199979663
16         	25	0.155999369603	0.956249978046	0.16374286809	0.953599982262
16         	26	0.154773398619	0.956449978848	0.162713578879	0.952599979639
16         	27	0.153050602519	0.956833310823	0.161768940929	0.954499979615
16         	28	0.1523436388	0.957249977589	0.160057133129	0.953799979687
16         	29	0.151902274067	0.957349978685	0.157392963963	0.955899979472
16         	30	0.150760911833	0.958133311073	0.157074658405	0.954999978542
16         	31	0.150044757755	0.958299977481	0.156240226631	0.955999979973
16         	32	0.149614738456	0.958399976889	0.156685252744	0.953999980688
16         	33	0.148363801216	0.95848331064	0.158817891693	0.954599981904
16         	34	0.147894677576	0.958766644796	0.156351987077	0.955199979544
16         	35	0.147041941471	0.959233310123	0.155217946498	0.955599982142
16         	36	0.147216913179	0.958783311148	0.155078442828	0.953399981856
16         	37	0.146468356115	0.959083309968	0.150124904611	0.956199980974
16         	38	0.146006768439	0.959149977068	0.156810603691	0.95299998045
16         	39	0.145642107148	0.958866644601	0.153719958966	0.954699979424
16         	40	0.145108838473	0.959433309933	0.154707880723	0.956399978995
16         	41	0.144858170642	0.959799975852	0.152054326017	0.955699979663
16         	42	0.144926028823	0.959216643671	0.152334149359	0.955699980259
16         	43	0.144456894721	0.959516643385	0.154372081785	0.956599978209
16         	44	0.143522921739	0.9598499765	0.155528011341	0.955099980831
16         	45	0.144661245228	0.959349976778	0.157224352071	0.953299978971
16         	46	0.143984464481	0.959666643341	0.153887000745	0.955499979854
16         	47	0.142926530199	0.959383310477	0.153557347101	0.955699980259
16         	48	0.143312502224	0.960233309766	0.15385866784	0.956399978399
16         	49	0.142816281555	0.960099977255	0.149932303722	0.957499978542
16         	50	0.142874155442	0.959649977088	0.15262463226	0.955399979949
duration = 119.626075983


hidden_size	epoch	train_loss	train_accuracy	test_loss	test_accuracy
9          	1	0.818656118015	0.752916648195	0.371229402311	0.899699982405
9          	2	0.347695696205	0.902599981527	0.305832790583	0.914399981499
9          	3	0.307536870763	0.913833314975	0.287019745894	0.920599979758
9          	4	0.292015001973	0.917383315265	0.279503905214	0.923099979758
9          	5	0.282326334082	0.92033331732	0.271882737838	0.925899978876
9          	6	0.276010201064	0.922599982719	0.267324015908	0.926599978805
9          	7	0.270871964333	0.924866648118	0.264522304684	0.927299978733
9          	8	0.266869520495	0.925149983168	0.265096761417	0.927099981308
9          	9	0.264471063068	0.92618331631	0.261842615288	0.927999978662
9          	10	0.261469816603	0.927549983164	0.259355878271	0.926199977994
9          	11	0.259488549878	0.92784998248	0.256749682594	0.926799977422
9          	12	0.257635269997	0.928316649695	0.25626241168	0.929299976826
9          	13	0.255620451532	0.928899983068	0.260957205975	0.92589997828
9          	14	0.25482062038	0.929066648285	0.257922151154	0.926799978614
9          	15	0.253305883047	0.92921664993	0.255847413447	0.927199978828
9          	16	0.251593413614	0.929833316207	0.258977433397	0.927799979448
9          	17	0.25028087421	0.929916649361	0.25329032287	0.930399979353
9          	18	0.249563378505	0.929249981046	0.252542307749	0.926799976826
9          	19	0.248701446628	0.93041664958	0.256687079854	0.926699975729
9          	20	0.247404867237	0.930683316092	0.251968266303	0.92869997859
9          	21	0.24693934648	0.930683315098	0.249477285519	0.931799978614
9          	22	0.246106656964	0.930966649652	0.251796464976	0.930899977684
9          	23	0.245129487601	0.931116649508	0.249356075656	0.930399976373
9          	24	0.244708089419	0.931883315047	0.256087167142	0.926199976802
9          	25	0.244231025216	0.931049982707	0.250778504107	0.929599978924
9          	26	0.243249550772	0.93171664993	0.247441170095	0.932199976444
9          	27	0.241764034455	0.931233315368	0.251725307954	0.926399977803
9          	28	0.241538659781	0.932283315857	0.248131537279	0.929299978614
9          	29	0.240854494373	0.93224998186	0.251995971082	0.929499979019
9          	30	0.240782345446	0.931699982584	0.245075793937	0.932199978232
9          	31	0.240070650789	0.932533315519	0.249905576669	0.929499977827
9          	32	0.239268037416	0.932633317312	0.249562492426	0.928199977279
9          	33	0.238447342999	0.933733315865	0.250130536882	0.927599977851
9          	34	0.238299275363	0.932833316525	0.251445378168	0.927899976969
9          	35	0.237109052241	0.932583316366	0.24843442766	0.930899978876
9          	36	0.236983723703	0.93381664892	0.245170613984	0.930599979162
9          	37	0.235480831477	0.933883316418	0.247532728836	0.929599978328
9          	38	0.235663501173	0.933833315273	0.24737455776	0.930299978256
9          	39	0.234135231239	0.933833316366	0.245096690897	0.930499978662
9          	40	0.234543845157	0.934316649139	0.24585830708	0.930299976468
9          	41	0.233676121285	0.934883314868	0.245880237129	0.93019998014
9          	42	0.233009988802	0.934966649711	0.243903771387	0.930999978781
9          	43	0.23281694	0.935099982023	0.246915497445	0.92989997983
9          	44	0.232201331357	0.934533316195	0.24271708387	0.930999977589
9          	45	0.231652550784	0.935233315726	0.24878255289	0.9289999789
9          	46	0.230845863782	0.935316649477	0.243454007944	0.931099976897
9          	47	0.231136437071	0.935849982003	0.24541374648	0.929199980497
9          	48	0.229606743579	0.935899982651	0.242632134054	0.932299978733
9          	49	0.229666709527	0.935249983172	0.244335451946	0.930299978256
9          	50	0.228779803676	0.936316648523	0.241723142639	0.929899978042
duration = 118.98732996



In [27]:
# 結果のデータフレームを作成
df_variable_size = pd.DataFrame(data_variable_size)
df_variable_size


Out[27]:
duration hidden_size test_accuracy test_loss train_accuracy train_loss
0 142.791613 81 0.9771 0.083719 0.981117 0.074987
1 135.528563 64 0.9754 0.083252 0.980450 0.075686
2 133.285835 49 0.9742 0.097522 0.979067 0.079034
3 129.395678 36 0.9733 0.100244 0.976533 0.089373
4 119.104343 25 0.9693 0.110670 0.973617 0.099629
5 119.626076 16 0.9554 0.152625 0.959650 0.142874
6 118.987330 9 0.9299 0.241723 0.936317 0.228780

隠れ層のニューロン数に対する誤差の変化は以下のとおりです。


In [28]:
plt.figure()
ax = df_variable_size.plot(x="hidden_size", y="train_loss", style="k-")
df_variable_size.plot(x="hidden_size", y="test_loss", style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'隠れ層のニューロン数')
ax.set_ylabel(u'誤差')
ax.legend(labels=[u'訓練', u'テスト'], loc=u'upper left')
plt.ylim(ymin=0)
plt.tight_layout()


<matplotlib.figure.Figure at 0x1312b8750>

隠れ層のニューロン数に対する正解率の変化は以下のとおりです。


In [29]:
plt.figure()
ax = df_variable_size.plot(x="hidden_size", y="train_accuracy", style="k-")
df_variable_size.plot(x="hidden_size", y="test_accuracy", style="k--", dashes=(3, 1.5), ax=ax)
ax.set_xlabel(u'隠れ層のニューロン数')
ax.set_ylabel(u'正解率')
ax.legend(labels=[u'訓練', u'テスト'], loc=u'lower left')
plt.ylim(ymin=0.8, ymax=1)
plt.tight_layout()


<matplotlib.figure.Figure at 0x131517150>

隠れ層のニューロン数に対する計算時間の変化は以下のとおりです。


In [30]:
plt.figure()
ax = df_variable_size.plot(x="hidden_size", y="duration", style="k-", legend=False)
ax.set_xlabel(u'隠れ層のニューロン数')
ax.set_ylabel(u'学習時間 (秒)')
plt.tight_layout()


<matplotlib.figure.Figure at 0x1321fa910>

以上の結果は1回の実行で得られた計測値を示しています。同じ実験を何度も実行し、計測値の平均と標準偏差を求めると、もっと正確な傾向を把握できます。