As you will have noticed from running the MNIST example, training takes quite a bit of time. For this reason (among others), the usual practice is to save the model's parameters after training and then load those parameters back in at evaluation time.
The functions used for this are torch.save, torch.load, model.state_dict(), and model.load_state_dict(). In fact, at the end of the Chapter 4 tutorial we already saved the model parameters with torch.save. So in this chapter we will skip training and instead restore the model's parameters from the saved file and use them. The basic pattern looks like this:
## training end
torch.save(model.state_dict(), checkpoint_filename)
## evaluating start
checkpoint = torch.load(checkpoint_filename)
model.load_state_dict(checkpoint)
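Beyond the bare state_dict, it is common to bundle extra bookkeeping (the epoch counter, the optimizer state) into a single dictionary and save that instead. Below is a minimal sketch of this pattern; epoch and optimizer stand for your training loop's epoch counter and optimizer, and the key names ('epoch', 'model_state', 'optimizer_state') are our own choice, not a PyTorch convention.
## saving a richer checkpoint (sketch; key names are arbitrary)
torch.save({
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
}, checkpoint_filename)
## restoring it later
checkpoint = torch.load(checkpoint_filename)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch']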
In [1]:
%matplotlib inline
In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
is_cuda = torch.cuda.is_available()  # True if CUDA is available
checkpoint_filename = 'minist.ckpt'
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=100, shuffle=False)
In [3]:
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        # feature map size is 14*14 after pooling
        # padding=2 for same padding
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        # feature map size is 7*7 after pooling
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)  # flatten the feature maps
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

model = MnistModel()
if is_cuda: model.cuda()
In [4]:
checkpoint = torch.load(checkpoint_filename)
model.load_state_dict(checkpoint)
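Note that load_state_dict only copies parameter values, so the model you construct must have the same architecture as the one that was saved. One further pitfall: a checkpoint saved on a GPU will, by default, try to load its tensors back onto a GPU. If you need to evaluate on a CPU-only machine, torch.load accepts a map_location argument to remap the storages. A sketch, assuming the file in checkpoint_filename was saved from a CUDA model:
# load GPU-saved parameters onto the CPU (sketch)
checkpoint = torch.load(checkpoint_filename, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint)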
In [5]:
model.eval()
correct = 0
for image, target in test_loader:
    if is_cuda: image, target = image.cuda(), target.cuda()
    image, target = Variable(image, volatile=True), Variable(target)
    output = model(image)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()
print('\nTest set: Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
In [6]:
model.state_dict().keys()
Out[6]:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
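state_dict() maps each parameter name to its tensor, so you can also inspect what was restored. A small sketch that prints each parameter's shape:
# inspect the restored parameters (sketch)
for name, param in model.state_dict().items():
    print(name, tuple(param.size()))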
In [7]:
plt.rcParams["figure.figsize"] = [8, 4]
weight = model.state_dict()['conv1.weight']
wmax, wmin = torch.max(weight), torch.min(weight)
gridimg = torchvision.utils.make_grid(weight).cpu().numpy().transpose((1, 2, 0))
plt.imshow(gridimg[:, :, 0], vmin=wmin, vmax=wmax, interpolation='nearest', cmap='seismic')  # gridimg[:, :, 0] shows a single color channel
Out[7]:
(figure: grid of the 32 conv1 filters, 5x5 each)
In [8]:
plt.rcParams["figure.figsize"] = [8, 8]
weight = model.state_dict()['conv2.weight']  # 64 x 32 x 5 x 5
weight = weight[:, 0:1, :, :]  # 64 x 1 x 5 x 5
wmax, wmin = torch.max(weight), torch.min(weight)
gridimg = torchvision.utils.make_grid(weight).cpu().numpy().transpose((1, 2, 0))
plt.imshow(gridimg[:, :, 0], vmin=wmin, vmax=wmax, interpolation='nearest', cmap='seismic')  # gridimg[:, :, 0] shows a single color channel
Out[8]:
(figure: grid of the 64 conv2 filters, input channel 0, 5x5 each)