In [ ]:
######################################################################
################### This is a line of 70 characters ##################
######################################################################
In [ ]:
import numpy
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim
from utils import write_file_and_close, check_control
import os
import errno
In [ ]:
# ---- Run configuration ------------------------------------------------
# Hyper-parameters and bookkeeping knobs for the CIFAR-10 ResNet run.
global_batch_size = 128  # mini-batch size for both train and test loaders
global_resnet_n = 3  # "n" in ResNet-(6n+2); n=3 gives a 20-layer net
global_conv_bias = True  # whether conv layers carry a bias term
global_data_print_freq = 20  # log running loss every N mini-batches
global_epoch_num = 200  # total number of training epochs
global_cuda_available = True  # move model/tensors to the GPU when True
global_output_filename = "out.txt"  # training log, written via utils helper
global_control_filename = "control.txt"  # polled each epoch to keep running
global_epoch_test_freq = 1  # evaluate on the test set every N epochs
if global_cuda_available:
    # Pin the job to a single physical GPU; must be set before CUDA init.
    os.environ["CUDA_VISIBLE_DEVICES"] = "5"
In [ ]:
# Training pipeline: standard CIFAR-10 augmentation (pad-4 random crop to
# 32x32 plus horizontal flip), then scale each channel to [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Here per-pixel mean isn't subtracted
# Test pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
In [ ]:
# CIFAR-10 datasets and loaders; the data is downloaded to ./data on the
# first run.  Training batches are shuffled, test batches keep a fixed
# order so evaluation results are reproducible.
trainset = torchvision.datasets.CIFAR10(
    root="./data", download=True, train=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=global_batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
    root="./data", download=True, train=False, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=global_batch_size, shuffle=False, num_workers=2
)
In [ ]:
class StartBlock(nn.Module):
    """Stem of the ResNet.

    A single 3-to-out_planes convolution followed by batch
    normalization and a ReLU.
    """

    def __init__(self, out_planes, kernel_size):
        super(StartBlock, self).__init__()
        # Keep the constructor arguments around for introspection.
        self.out_plane = out_planes
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(
            3, out_planes, kernel_size=kernel_size,
            padding=1, bias=global_conv_bias
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        # conv -> batch norm -> ReLU, written as one chained expression.
        return functional.relu(self.bn(self.conv(x)))
class BasicBlock(nn.Module):
    """Repeated pre-activation residual block.

    Layout: BN -> ReLU -> conv -> BN -> ReLU -> conv, plus a 1x1
    convolutional shortcut from the block input added to the branch
    output.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: spatial size of both conv kernels.
        stride: stride of the first conv (and of the shortcut), used to
            down-sample between stages.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride):
        super(BasicBlock, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(in_planes)
        # BUG FIX: both convs passed ``bias=global_batch_size`` (an int,
        # 128) instead of the ``global_conv_bias`` flag.  Since any
        # non-zero int is truthy this happened to behave like bias=True,
        # but the flag had no effect; it now actually controls the bias.
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size,
            stride=stride, padding=1, bias=global_conv_bias
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=kernel_size,
            padding=1, bias=global_conv_bias
        )
        # NOTE(review): a 1x1 conv shortcut is used even when the shape
        # is unchanged (in_planes == out_planes, stride == 1); standard
        # ResNets use an identity shortcut there.  Left as-is because
        # changing it alters the parameter count -- confirm intent.
        self.shortcut = nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride
        )

    def forward(self, x):
        out = self.bn1(x)
        out = functional.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = functional.relu(out)
        out = self.conv2(out)
        # Residual connection: projected input added to the branch output.
        out += self.shortcut(x)
        return out
class EndBlock(nn.Module):
    """Head of the ResNet.

    Global average pooling over both spatial dimensions followed by a
    fully connected classifier (10 classes for CIFAR-10).

    Args:
        in_planes: number of channels produced by the last residual
            stage; also the input width of the linear layer.
    """

    def __init__(self, in_planes):
        super(EndBlock, self).__init__()
        self.fc = nn.Linear(in_planes, 10)

    def forward(self, x):
        # Global average pool on an (N, C, H, W) input: reduce W (dim 3)
        # first, then H (dim 2).  BUG FIX: the original reduced dim 2 and
        # then dim 3, but once a dimension is squeezed out the tensor is
        # 3-D and ``dim=3`` is out of range.  Reducing in this order is
        # correct regardless of keepdim semantics.
        out = torch.mean(x, dim=3)
        out = torch.mean(out, dim=2)
        out = out.view(out.size()[0], -1)
        out = self.fc(out)
        return out
In [ ]:
class ResNet(nn.Module):
    """ResNet-(6n + 2) for CIFAR-10.

    Three stages of width 16, 32 and 64; each stage contains n residual
    blocks, with a stride-2 block performing the transition between
    stages.
    """

    def __init__(self, n):
        super(ResNet, self).__init__()
        # Build the layer sequence stage by stage (construction order is
        # significant for parameter-initialization RNG state).
        layers = [StartBlock(16, 3)]
        layers += [BasicBlock(16, 16, 3, 1) for _ in range(n)]
        layers.append(BasicBlock(16, 32, 3, 2))
        layers += [BasicBlock(32, 32, 3, 1) for _ in range(n - 1)]
        layers.append(BasicBlock(32, 64, 3, 2))
        layers += [BasicBlock(64, 64, 3, 1) for _ in range(n - 1)]
        layers.append(EndBlock(64))
        self.block_list = layers
        self.blocks = nn.Sequential(*self.block_list)

    def forward(self, x):
        return self.blocks(x)
In [ ]:
# Instantiate the model and, when configured, move its parameters to the
# GPU selected via CUDA_VISIBLE_DEVICES above.
net = ResNet(global_resnet_n)
if global_cuda_available:
    net.cuda()
In [ ]:
# Cross-entropy classification loss; SGD with momentum and weight decay.
# The lr given here is a placeholder -- it is overwritten every epoch
# from lr_adjust() in the training loop below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001
)
def lr_adjust(it):
    """Return the learning rate for global iteration *it*.

    Piecewise-constant schedule: 0.1 up to 32k iterations, decayed by
    10x at 32k, 48k and 64k iterations.
    """
    schedule = ((32000, 0.1), (48000, 0.01), (64000, 0.001))
    for bound, rate in schedule:
        if it < bound:
            return rate
    return 0.0001
In [ ]:
def train(data, info):
    """Run one optimization step on a single mini-batch.

    Args:
        data: an (inputs, labels) pair as yielded by ``trainloader``.
        info: two-element list used as an out-parameter;
            info[0] <- scalar loss value for this batch,
            info[1] <- number of samples in the batch.
    """
    global net, optimizer, criterion
    inputs, labels = data
    # Variable(...) and loss.data[0] are pre-0.4 PyTorch idioms; kept
    # as-is for consistency with the rest of the file.
    inputs, labels = Variable(inputs), Variable(labels)
    if global_cuda_available:
        inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    info[0] = loss.data[0]
    info[1] = labels.size()[0]
def test(info):
    """Evaluate the current net on the entire test set.

    Args:
        info: three-element list used as an out-parameter;
            info[0] <- number of correctly classified samples,
            info[1] <- total number of samples evaluated,
            info[2] <- sum of the per-batch loss values.
    """
    global net
    correct_sum = 0
    total_loss_sum = 0.
    total_ctr = 0
    for data in testloader:
        inputs, labels = data
        # Variable(...) / loss.data[0] are pre-0.4 PyTorch idioms.
        inputs, labels = Variable(inputs), Variable(labels)
        if global_cuda_available:
            inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net(inputs)
        # Predicted class = arg-max over the 10 logits.
        _, predicted = torch.max(outputs.data, 1)
        total_ctr += labels.size()[0]
        correct_sum += (predicted == labels.data).sum()
        loss = criterion(outputs, labels)
        total_loss_sum += loss.data[0]
    info[0] = correct_sum
    info[1] = total_ctr
    info[2] = total_loss_sum
In [ ]:
# Truncate the log file and record the loader sizes before training.
write_file_and_close(global_output_filename, "Cleaning...", flag="w")
write_file_and_close(
    global_output_filename,
    "The length of trainloader and testloader is {:d} and {:d} resp."
    .format(len(trainloader), len(testloader))
)
In [ ]:
write_file_and_close(global_output_filename, "Start training")
it = 0  # global iteration counter driving the lr schedule
for epoch in range(global_epoch_num):
    # Stop gracefully when the external control file revokes the run.
    # BUG FIX: the original referenced a misspelled name
    # (``gloabl_output_filename``), which raised NameError the moment
    # control was lost, logged "Control llost", and never left the loop.
    if not check_control(global_control_filename):
        write_file_and_close(global_output_filename, "Control lost")
        break
    running_loss_sum = 0.  # loss accumulated since the last progress print
    total_loss_sum = 0.    # loss accumulated over the whole epoch
    ctr_sum = 0            # mini-batches since the last progress print
    total_ctr = 0          # samples seen this epoch
    # Apply the scheduled learning rate (updated once per epoch).
    for g in optimizer.param_groups:
        g["lr"] = lr_adjust(it)
    for i, data in enumerate(trainloader):
        info = [0., 0]
        train(data, info)
        running_loss_sum += info[0]
        total_loss_sum += info[0]
        ctr_sum += 1
        total_ctr += info[1]
        if (i + 1) % global_data_print_freq == 0:
            write_file_and_close(
                global_output_filename,
                "epoch: {:d}, "
                "train set index: {:d}, "
                "average loss: {:.10f}"
                .format(epoch, i, running_loss_sum / ctr_sum)
            )
            running_loss_sum = 0.0
            ctr_sum = 0
        it = it + 1
    # NOTE(review): this divides the summed batch losses by the number of
    # *samples*, while the running average above divides by the number of
    # *batches* -- the two figures are on different scales.  Kept as-is;
    # confirm which denominator is intended.
    write_file_and_close(
        global_output_filename,
        "Epoch {:d} finished, average loss: {:.10f}"
        .format(epoch, total_loss_sum / total_ctr)
    )
    if (epoch + 1) % global_epoch_test_freq == 0:
        write_file_and_close(global_output_filename, "Starting testing")
        info = [0., 0., 0.]
        test(info)
        write_file_and_close(
            global_output_filename,
            "Correct: {:d}, total: {:d}, "
            "accuracy: {:.10f}, average loss: {:.10f}"
            .format(info[0], info[1], info[0] / info[1], info[2] / info[1])
        )
        write_file_and_close(global_output_filename, "Finished testing")
In [ ]: