Save & Restore with a minist example

Minist예제를 수행하면 알겠지만, Train에 생각보다는 꽤 많은 시간이 소요됩니다. 이 이유만이 아니라 평가시에는 trainnig후에 model의 parameter를 저장했다가 평가시에는 그 parameter를 불러들여서 사용하는 것이 일반적입니다.

여기에 사용되는 함수는 torch.save, torch.load와 model.state_dict(), model.load_state_dict()입니다. 사실 4장의 tutorial의 마지막에 torch.save를 이용하여 model parameter를 저장을 했습니다. 따라서 이번 장에서는 train과정 없이 save된 file로 부터 model의 parameter를 복구하여 사용해보도록 하겠습니다.

## training end
torch.save(model.state_dict(), checkpoint_filename)

## evaluating start
checkpoint = torch.load(checkpoint_filename)
model.load_state_dict(checkpoint)

In [1]:
%matplotlib inline

1. 입력DataLoader 설정


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

is_cuda = torch.cuda.is_available() # cuda 사용가능시, True
checkpoint_filename = 'minist.ckpt'

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=100, shuffle=False)

2. 사전 설정

* model
* loss (train을 하지 않으므로, 생략)
* opimizer (train을 하지 않으므로, 생략)

In [3]:
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        # feature map size is 14*14 by pooling
        # padding=2 for same padding
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        # feature map size is 7*7 by pooling
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)   # reshape Variable
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)
    
model = MnistModel()
if is_cuda :  model.cuda()

3. Restore model paramter from saved file


In [4]:
checkpoint = torch.load(checkpoint_filename)
model.load_state_dict(checkpoint)

6. Predict & Evaluate

train을 하지 않았음에도 이전에 학습된 model parameter를 복원하여 정확도가 98%이상인 것을 알 수 있습니다.


In [5]:
model.eval()
correct = 0
for image, target in test_loader:
    if is_cuda :  image, target = image.cuda(), target.cuda() 
    image, target = Variable(image, volatile=True), Variable(target)
    output = model(image)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()

print('\nTest set: Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))


Test set: Accuracy: 98.64%

5. plot weights

model의 weight를 plot하여 봅니다.


In [6]:
model.state_dict().keys()


Out[6]:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

In [7]:
plt.rcParams["figure.figsize"] = [8, 4]

weight = model.state_dict()['conv1.weight']
wmax, wmin = torch.max(weight), torch.min(weight)
gridimg = torchvision.utils.make_grid(weight).cpu().numpy().transpose((1,2,0))
plt.imshow(gridimg[:,:,0], vmin = wmin, vmax =wmax, interpolation='nearest', cmap='seismic') # gridimg[:, :, 0]는 한 color channel을 출력


Out[7]:
<matplotlib.image.AxesImage at 0x7f77589ff320>

In [8]:
plt.rcParams["figure.figsize"] = [8, 8]

weight = model.state_dict()['conv2.weight'] # 64 x 32 x 5 x 5
weight = weight[:, 0:1, :, :] # 64 x 1 x 5 x 5
wmax, wmin = torch.max(weight), torch.min(weight)
gridimg = torchvision.utils.make_grid(weight).cpu().numpy().transpose((1,2,0))
plt.imshow(gridimg[:,:,0], vmin = wmin, vmax =wmax, interpolation='nearest', cmap='seismic') # gridimg[:, :, 0]는 한 color channel을 출력


Out[8]:
<matplotlib.image.AxesImage at 0x7f7758991828>