08_NN_MNIST


In [6]:
# PyTorch library
import torch
import torch.nn.init
from torch.autograd import Variable

torch.manual_seed(777)  # reproducibility


Out[6]:
<torch._C.Generator at 0x7f368006a258>

In [7]:
import torchvision.utils as utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [27]:
# Other Python libraries
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import random

Loading MNIST dataset


In [9]:
# MNIST dataset
mnist_train = dsets.MNIST(root='data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

In [28]:
# plot one example
print(mnist_train.train_data.size())                 # (60000, 28, 28)
print(mnist_train.train_labels.size())               # (60000)

idx = 0
plt.imshow(mnist_train.train_data[idx,:,:].numpy(), cmap='gray')
plt.title('%i' % mnist_train.train_labels[idx])


torch.Size([60000, 28, 28])
torch.Size([60000])
Out[28]:
<matplotlib.text.Text at 0x7f35dbf855c0>

Data Loader (splits the full dataset and serves it in batch-size chunks)


In [10]:
# Hyper-parameters
batch_size = 100

# dataset loader
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)

In [19]:
def imshow(img):
    # images from ToTensor are already in [0, 1]; no unnormalization needed
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# As an example, fetch a single batch from data_loader
batch_images, batch_labels = next(iter(data_loader))

print(batch_images.size())
print(batch_labels.size())

# show images and print labels
imshow(utils.make_grid(batch_images))
batch_labels.numpy()

# Typical usage: read images and labels with a for loop.
# for batch_images, batch_labels in data_loader:
#     print(batch_images.size())
#     print(batch_labels)

# With enumerate, you also get the batch index along with the images and labels.
# for i, (batch_images, batch_labels) in enumerate(data_loader):
#     print(batch_images.size())
#     print(batch_labels)


torch.Size([100, 1, 28, 28])
torch.Size([100])
Out[19]:
array([8, 8, 2, 7, 8, 6, 4, 3, 7, 3, 3, 6, 4, 1, 1, 5, 9, 4, 2, 0, 0, 6, 9,
       4, 0, 3, 9, 2, 6, 4, 5, 7, 4, 9, 8, 4, 7, 6, 4, 1, 8, 6, 7, 7, 2, 5,
       4, 1, 3, 7, 4, 8, 2, 8, 9, 3, 8, 8, 0, 6, 8, 0, 4, 3, 8, 2, 4, 6, 2,
       6, 4, 9, 7, 7, 0, 6, 8, 2, 6, 9, 3, 9, 3, 4, 4, 1, 2, 1, 0, 1, 2, 4,
       4, 1, 5, 7, 8, 9, 3, 8])

Define Neural Network Model


In [48]:
# Neural Network
linear1 = torch.nn.Linear(784, 512, bias=True)
linear2 = torch.nn.Linear(512, 10, bias=True)
relu = torch.nn.ReLU()
#sigmoid = torch.nn.Sigmoid()

# model
model = torch.nn.Sequential(linear1, relu, linear2)   

#model.load_state_dict(torch.load('NN.pkl'))  # Load the Trained Model
print(model)


Sequential (
  (0): Linear (784 -> 512)
  (1): ReLU ()
  (2): Linear (512 -> 10)
)
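
The same network can also be written as an nn.Module subclass, the usual idiom once a model grows beyond a plain Sequential stack. A minimal sketch, reusing the imports from the top of the notebook (the class name TwoLayerNet is made up for illustration):

In [ ]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(784, 512, bias=True)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(512, 10, bias=True)

    def forward(self, x):
        # x: [batch_size, 784] -> logits: [batch_size, 10]
        return self.linear2(self.relu(self.linear1(x)))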

Define the Cost Function


In [41]:
# Softmax is built into the cost computation (CrossEntropyLoss applies LogSoftmax internally).
cost_func = torch.nn.CrossEntropyLoss()
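
CrossEntropyLoss takes raw logits and integer class labels and applies LogSoftmax plus NLLLoss internally, which is why the model above ends in a plain Linear layer. A quick sanity-check sketch on dummy data (shapes picked arbitrarily):

In [ ]:
dummy_logits = Variable(torch.randn(4, 10))              # 4 samples, 10 classes
dummy_labels = Variable(torch.LongTensor([1, 0, 4, 9]))  # integer labels, not one-hot

ce = torch.nn.CrossEntropyLoss()(dummy_logits, dummy_labels)
nll = torch.nn.NLLLoss()(torch.nn.LogSoftmax()(dummy_logits), dummy_labels)
print(ce.data[0], nll.data[0])  # the two values should agree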

Train the Model


In [51]:
# Hyper-parameters
learning_rate = 0.001 
training_epochs = 5

# Adam Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train model
for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = len(mnist_train) // batch_size

    for i, (batch_images, batch_labels) in enumerate(data_loader):
        
        # reshape each image into a [batch_size x 784] matrix
        X = Variable(batch_images.view(-1, 28 * 28))
        Y = Variable(batch_labels)        # labels are not one-hot encoded

        optimizer.zero_grad()             # zero the gradient buffers
        Y_prediction = model(X)           # Forward Propagation
        cost = cost_func(Y_prediction, Y) # compute cost
        cost.backward()                   # compute gradient
        optimizer.step()                  # gradient update

        avg_cost += cost / total_batch

    print("[Epoch: {:>4}] cost = {:>.9}".format(epoch + 1, avg_cost.data[0]))

print('Learning Finished!')
torch.save(model.state_dict(), 'NN.pkl')  # Save the Model


[Epoch:    1] cost = 0.291562766
[Epoch:    2] cost = 0.114509255
[Epoch:    3] cost = 0.0747233555
[Epoch:    4] cost = 0.0530239083
[Epoch:    5] cost = 0.0399747863
Learning Finished!
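
Since the trained weights were just saved to 'NN.pkl', a later session can rebuild the same architecture and restore them. A minimal sketch (layer sizes must match what was saved):

In [ ]:
restored = torch.nn.Sequential(torch.nn.Linear(784, 512, bias=True),
                               torch.nn.ReLU(),
                               torch.nn.Linear(512, 10, bias=True))
restored.load_state_dict(torch.load('NN.pkl'))  # parameters load by key ('0.weight', ...)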

In [50]:
model.state_dict()


Out[50]:
OrderedDict([('0.weight', ... [torch.FloatTensor of size 512x784]),
             ('0.bias', ... [torch.FloatTensor of size 512]),
             ('2.weight', ... [torch.FloatTensor of size 10x512]),
             ('2.bias', ... [torch.FloatTensor of size 10])])
(values truncated)

Measure Model Performance on the Test Dataset


In [43]:
# Test the model on the held-out test set
correct = 0
total = 0
for images, labels in mnist_test:
    images = Variable(images.view(-1, 28 * 28))
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)  # index of the max logit
    total += 1
    correct += (predicted == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# Note: the 9 % printed below comes from this cell (In [43]) having run
# before the training cell (In [51]); re-run after training for the real score.


Accuracy of the network on the 10000 test images: 9 %
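
Evaluating one sample at a time is slow; the idiomatic approach is a second DataLoader over the test set, mirroring training. A sketch (shuffle=False since order is irrelevant for accuracy; view(-1) guards against the extra dimension torch.max keeps in older PyTorch versions):

In [ ]:
test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=100,
                                          shuffle=False)
correct = 0
total = 0
for images, labels in test_loader:
    outputs = model(Variable(images.view(-1, 28 * 28)))
    _, predicted = torch.max(outputs.data, 1)        # index of the max logit per row
    total += labels.size(0)
    correct += (predicted.view(-1) == labels).sum()

print('Accuracy: %d %%' % (100 * correct / total))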

Random Sample Test


In [49]:
# Get one and predict
r = random.randint(0, len(mnist_test) - 1)
X_single_data = Variable(mnist_test.test_data[r:r + 1].view(-1, 28 * 28).float())
Y_single_data = Variable(mnist_test.test_labels[r:r + 1])

single_prediction = model(X_single_data)

plt.imshow(X_single_data.data.view(28,28).numpy() , cmap='gray')

print("Label: ", Y_single_data.data)
print("Prediction: ", torch.max(single_prediction.data, 1)[1])


Label:  
 8
[torch.LongTensor of size 1]

Prediction:  
 3
[torch.LongTensor of size 1]


In [30]:
# Visualize the first 20 hidden-unit weight vectors as 28x28 images
for i in range(20):
    weight = model[0].weight[i, :].data.view(28, 28)
    # min-max rescale to [0, 1] so each filter spans the full gray range
    weight = (weight - torch.min(weight)) / (torch.max(weight) - torch.min(weight))
    plt.imshow(weight.numpy(), cmap='gray')
    plt.show()
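
Instead of 20 separate figures, the same filters can be drawn as a single grid with the make_grid helper imported earlier. A sketch (here the rescaling runs over all 20 filters at once, rather than per filter as above):

In [ ]:
w = model[0].weight.data[:20].view(-1, 1, 28, 28)  # 20 filters as 1x28x28 images
w = (w - w.min()) / (w.max() - w.min())            # rescale to [0, 1] for display
w = torch.cat((w, w, w), 1)                        # replicate to 3 channels for display
plt.imshow(utils.make_grid(w, nrow=5).numpy().transpose(1, 2, 0))
plt.show()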



In [ ]: