In [ ]:
import numpy
import torch
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as functional

import torch.optim as optim

import init

from utils import write_file_and_close, check_control
from utils import generate_filename

import os
import errno

In [ ]:
global_batch_size = 128
global_conv_bias = True
global_data_print_freq = 20
global_epoch_num = 200
global_cuda_available = True
global_output_filename = "out.txt"
global_control_filename = "control.txt"
global_epoch_test_freq = 1

global_obsolete_pytorch = False

if global_cuda_available:
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"

global_weight_init_range = (-0.1, 0.0)

In [ ]:
global_resnet_n = 5

nn_arch_start = [16, 3]
nn_arch_weight = (
      [[16, 16, 3, 1]] * global_resnet_n
    + [[16, 32, 3, 2]]
    + [[32, 32, 3, 1]] * (global_resnet_n - 1)
    + [[32, 64, 3, 2]]
    + [[64, 64, 3, 1]] * (global_resnet_n - 1)
)
nn_arch_residue1 = (
      [[16, 16, 1]] * global_resnet_n
    + [[16, 32, 2]]
    + [[32, 32, 1]] * (global_resnet_n - 1)
    + [[32, 64, 2]]
    + [[64, 64, 1]] * (global_resnet_n - 1)
)
nn_arch_residue2 = (
      [None]
    + [[16, 16, 1]] * (global_resnet_n - 1)
    + [None] * 2
    + [[32, 32, 1]] * (global_resnet_n - 2)
    + [None] * 2
    + [[64, 64, 1]] * (global_resnet_n - 2)
)
nn_arch_end = [64, 10]
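
A quick sanity check on the lists above (illustrative only; 6n + 2 is the usual depth formula for the CIFAR ResNet):

In [ ]:
# Sanity check (illustrative, safe to delete): the three per-block lists must
# align one-to-one, and the depth is 6n + 2 (32 layers for n = 5): two convs
# per basic block, plus the first conv and the final fc.
assert len(nn_arch_weight) == len(nn_arch_residue1) == len(nn_arch_residue2)
print("basic blocks:", len(nn_arch_weight))          # 3n = 15
print("total depth:", 2 * len(nn_arch_weight) + 2)   # 6n + 2 = 32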

In [ ]:
# Augmentation: take a random 28x28 crop of the 32x32 image and flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(28, padding=0),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# TODO: Calculate and subtract per-pixel average
transform_test = transforms.Compose([
    transforms.CenterCrop(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [ ]:
trainset = torchvision.datasets.CIFAR10(
    root="./data", download=True, train=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=global_batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root="./data", download=True, train=False, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=global_batch_size, shuffle=False, num_workers=2
)

In [ ]:
class StartBlock(nn.Module):
    """First several blocks for resnet
    
    Only contains a single layer of conv2d and a batch norm layer
    """

    def __init__(self, out_planes, kernel_size):
        super(StartBlock, self).__init__()
        self.out_planes = out_planes
        self.kernel_size = kernel_size

        self.conv = nn.Conv2d(
            3, out_planes, kernel_size=kernel_size,
            padding=1, bias=global_conv_bias
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = functional.relu(out)
        return out

class BasicBlock(nn.Module):
    """Repeated blocks for resnet
    
    Contains two conv layers, two batch norm layers
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride):
        super(BasicBlock, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size,
            stride=stride, padding=1, bias=global_conv_bias
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=kernel_size,
            padding=1, bias=global_conv_bias
        )

    def forward(self, x):
        out = self.bn1(x)
        out = functional.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = functional.relu(out)
        out = self.conv2(out)
        return out

class ShortcutBlock(nn.Module):
    """Shortcut blocks for resnet
    
    Contains a simple shortcut, which may be a simple identity or
    a 1*1 conv layer.
    """
    
    def __init__(self, in_planes, out_planes, stride):
        super(ShortcutBlock, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.stride = stride
        if in_planes != out_planes or stride != 1:
            self.shortcut = nn.AvgPool2d(1, stride=stride, padding=0)
        self.pad_size = []
    
    # Note: check_pad caches the zero-padding tensor so that it is not
    # reallocated on every forward pass; the caching logic is admittedly ugly.
    def check_pad(self, size):
        if self.pad_size == size:
            return self.pad
        else:
            if global_cuda_available:
                pad = Variable(torch.zeros(size)).cuda()
            else:
                pad = Variable(torch.zeros(size))
        self.pad = pad
        self.pad_size = size
        return self.pad
    
    def forward(self, x):
        if self.in_planes != self.out_planes or self.stride != 1:
            out = self.shortcut(x)
            size = list(out.size())
            size[1] = self.out_planes - self.in_planes
            out = torch.cat((out, self.check_pad(size)), 1)
            return out
        else:
            return x

class EndBlock(nn.Module):
    """Last several blocks for resnet
    
    Only contains a global average pooling layer and a fully
    connected layer.
    """

    def __init__(self, in_planes, out_classes):
        super(EndBlock, self).__init__()
        self.fc = nn.Linear(in_planes, out_classes)

    def forward(self, x):
        # Note: newer PyTorch squeezes the reduced dimension, so for a
        # tensor v of size (3, 4, 5, 6), torch.mean(v, dim=2) has size
        # (3, 4, 6) rather than (3, 4, 1, 6) as in older versions.
        if global_obsolete_pytorch:
            out = torch.mean(x, dim=2)
            out = torch.mean(out, dim=3)
            out = out.view(out.size()[0], -1)
        else:
            out = torch.mean(x, dim=2)
            out = torch.mean(out, dim=2)
        out = self.fc(out)
        return out
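
Two throwaway checks for the notes above (illustrative only, safe to delete):

In [ ]:
# Illustrative checks, not part of the training pipeline.
# 1) The padded shortcut (option A): a 16 -> 32 plane, stride-2 shortcut
#    should map (4, 16, 28, 28) to (4, 32, 14, 14), with the appended
#    16 channels all zero.
_sc = ShortcutBlock(16, 32, 2)
_x = Variable(torch.randn(4, 16, 28, 28))
if global_cuda_available:
    _x = _x.cuda()
print(_sc(_x).size())
# 2) The torch.mean semantics described in EndBlock: old PyTorch keeps the
#    reduced axis ((3, 4, 1, 6)), new PyTorch squeezes it ((3, 4, 6)).
print(torch.mean(torch.randn(3, 4, 5, 6), 2).size())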

In [ ]:
class WeightId(nn.Module):
    """A wrapped weight parameter
    
    Packed in order to be attached as attribute of an nn."""
    
    def __init__(self):
        super(WeightId, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1))
        self.weight.data.uniform_(*global_weight_init_range)
    
    def forward(self):
        return self.weight
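
WeightId wraps a single scalar so that, once attached to a parent module, it is trained like every other parameter; a quick illustration (hypothetical, not part of the model):

In [ ]:
# Illustration (not part of the model): the wrapper exposes one 1-element
# Parameter, initialised uniformly in global_weight_init_range, and calling
# the module returns the raw weight.
_w = WeightId()
print(list(_w.parameters()))
print(_w().size())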

In [ ]:
class ResNetWithWeight(nn.Module):
    """General ResNet with weight
    
    Contains BasicBlock, ShortcurBlock and momentum parameters.
    """
    def __init__(
        self, start, weight, residue1, residue2, end
    ):
        super(ResNetWithWeight, self).__init__()
        self.start = start
        self.weight = weight
        self.residue1 = residue1
        self.residue2 = residue2
        self.end = end
        self.start_block = StartBlock(*start)
        self.weight_list = []
        self.residue1_list = []
        self.residue2_list = []
        self.momentum_list = []
        self.link_list = []
        for w, r1, r2 in zip(weight, residue1, residue2):
            rctr = 0
            conv = BasicBlock(*w)
            self.weight_list.append(conv)
            if r1 is not None:
                self.residue1_list.append(ShortcutBlock(*r1))
                rctr += 1
            else:
                self.residue1_list.append(None)
            if r2 is not None:
                self.residue2_list.append(ShortcutBlock(*r2))
                rctr += 2
            else:
                self.residue2_list.append(None)
            self.link_list.append(rctr)
            if rctr == 3:
                self.momentum_list.append(WeightId())
            else:
                self.momentum_list.append(None)
        self.end_block = EndBlock(*end)
        # The plain Python lists above do not register submodules, so wrap
        # the non-None entries in nn.Sequential purely for registration.
        self.weight_block = nn.Sequential(
            *[u for u in self.weight_list if u is not None]
        )
        self.residue1_block = nn.Sequential(
            *[u for u in self.residue1_list if u is not None]
        )
        self.residue2_block = nn.Sequential(
            *[u for u in self.residue2_list if u is not None]
        )
        self.momentum_block = nn.Sequential(
            *[u for u in self.momentum_list if u is not None]
        )
    
    def forward(self, x):
        out = self.start_block(x)
        out1 = None
        out2 = None
        for w, r1, r2, m, l in zip(
            self.weight_list, self.residue1_list, self.residue2_list, self.momentum_list, self.link_list
        ):
            out2, out1 = out1, out
            if l == 3:
                if global_obsolete_pytorch:
                    tout2 = r2(out2)
                    tout1 = r1(out1)
                    out = (m().expand_as(tout2) * tout2
                           + (1. - m()).expand_as(tout1) * tout1
                           + w(out1))
                else:
                    out = m() * r2(out2) + (1. - m()) * r1(out1) + w(out1)
            elif l == 2:
                out = r2(out2) + w(out1)
            elif l == 1:
                out = r1(out1) + w(out1)
            elif l == 0:
                out = w(out1)
        out = self.end_block(out)
        return out

In [ ]:
net = ResNetWithWeight(
    nn_arch_start, nn_arch_weight, nn_arch_residue1, nn_arch_residue2,
    nn_arch_end
)

init.msra_init(net)

if global_cuda_available:
    net.cuda()
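
A forward-pass smoke test (illustrative; assumes the 28x28 crop size used in the transforms above):

In [ ]:
# Smoke test (illustrative): a dummy batch should come out as a (2, 10)
# score matrix, one score per CIFAR-10 class.
_x = Variable(torch.randn(2, 3, 28, 28))
if global_cuda_available:
    _x = _x.cuda()
print(net(_x).size())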

In [ ]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001
)

# TODO: Use scheduler to adjust learning rate
def lr_adjust(it):
    if it < 32000:
        return 0.1
    elif it < 48000:
        return 0.01
    elif it < 64000:
        return 0.001
    else:
        return 0.0001
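
The TODO above could be handled with torch.optim.lr_scheduler (available from PyTorch 0.2 on); a sketch, left commented out so it does not change this run:

In [ ]:
# Sketch for the TODO above (assumption: torch.optim.lr_scheduler exists in
# this PyTorch version). LambdaLR multiplies the base lr (0.001 here) by the
# returned factor, so the factors below reproduce lr_adjust exactly.
# scheduler = optim.lr_scheduler.LambdaLR(
#     optimizer, lr_lambda=lambda step: lr_adjust(step) / 0.001
# )
# ...and call scheduler.step() once per iteration instead of writing g["lr"].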

In [ ]:
def train(data, info):
    global net, optimizer, criterion
    inputs, labels = data
    inputs, labels = Variable(inputs), Variable(labels)
    if global_cuda_available:
        inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    info[0] = loss.data[0]
    info[1] = labels.size()[0]

In [ ]:
def test(info):
    global net
    correct_sum = 0
    total_loss_sum = 0.
    total_ctr = 0
    for data in testloader:
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        if global_cuda_available:
            inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total_ctr += labels.size()[0]
        correct_sum += (predicted == labels.data).sum()
        loss = criterion(outputs, labels)
        total_loss_sum += loss.data[0]
    info[0] = correct_sum
    info[1] = total_ctr
    info[2] = total_loss_sum

In [ ]:
write_file_and_close(global_output_filename, "Cleaning...", flag = "w")
write_file_and_close(
    global_output_filename,
    "The lengths of trainloader and testloader are {:d} and {:d} respectively."
    .format(len(trainloader), len(testloader))
)

write_file_and_close(global_output_filename, "Start training")

In [ ]:
it = 0
for epoch in range(global_epoch_num):
    if not check_control(global_control_filename):
        write_file_and_close(global_output_filename, "Control lost")
        break
    running_loss_sum = 0.
    total_loss_sum = 0.
    ctr_sum = 0
    total_ctr = 0
    for g in optimizer.param_groups:
        g["lr"] = lr_adjust(it)
    for i, data in enumerate(trainloader):
        info = [0., 0]
        train(data, info)
        running_loss_sum += info[0]
        total_loss_sum += info[0]
        ctr_sum += 1
        total_ctr += info[1]
        if (i + 1) % global_data_print_freq == 0:
            write_file_and_close(global_output_filename,
                "epoch: {:d}, "
                "train set index: {:d}, "
                "average loss: {:.10f}"
                .format(epoch, i, running_loss_sum / ctr_sum)
            )
            running_loss_sum = 0.0
            ctr_sum = 0
            # Log the learned momentum parameters at the same frequency.
            s = ""
            for m in net.momentum_list:
                if m is not None:
                    s += "{:.10f}, ".format(float(m.weight.data.cpu().numpy()))
            write_file_and_close(global_output_filename,
                "momentum parameters: {:s}".format(s[:-2])
            )
        it += 1
    write_file_and_close(global_output_filename,
        "Epoch {:d} finished, average loss: {:.10f}"
        .format(epoch, total_loss_sum / total_ctr)
    )
    if (epoch + 1) % global_epoch_test_freq == 0:
        write_file_and_close(global_output_filename,
                             "Starting testing"
        )
        info = [0., 0., 0.]
        test(info)
        write_file_and_close(global_output_filename,
            "Correct: {:d}, total: {:d}, "
            "accuracy: {:.10f}, average loss: {:.10f}"
            .format(
                info[0], info[1], info[0] / info[1], info[2] / info[1]
            )
        )
        write_file_and_close(global_output_filename, "Finished testing")

# TODO: Modify the filename
model_filename = generate_filename()
torch.save(net, model_filename)
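
A note on the save above (a common alternative, not a change to this run): torch.save(net, ...) pickles the whole module and ties the file to these class definitions; saving the state_dict is more portable:

In [ ]:
# Alternative save/load sketch (commented out; the full-module save above is
# kept as-is). Persist only the parameters and reload into a freshly built
# network of the same architecture.
# torch.save(net.state_dict(), model_filename)
# net2 = ResNetWithWeight(
#     nn_arch_start, nn_arch_weight, nn_arch_residue1, nn_arch_residue2,
#     nn_arch_end
# )
# net2.load_state_dict(torch.load(model_filename))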

In [ ]:
# Exploratory: accumulate the per-pixel sum of the (augmented, normalized)
# training images and average over the dataset; see the raw-image sketch
# below for the per-pixel-mean TODO.
s = numpy.zeros((3, 28, 28))
for i, d in enumerate(trainloader):
    d, _ = d
    if i == 0:
        print(d[2])  # peek at one sample
    s += d.numpy().sum(axis=0)
s /= len(trainset)
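
For the per-pixel-mean TODO near the transforms, the statistic should come from the raw, un-augmented images; a sketch (hypothetical, not wired into the pipeline):

In [ ]:
# Sketch for the per-pixel-mean TODO (hypothetical helper, not wired into
# transform_train/transform_test): average the raw 32x32 training images,
# without augmentation or normalization.
raw_set = torchvision.datasets.CIFAR10(
    root="./data", download=True, train=True, transform=transforms.ToTensor()
)
raw_loader = torch.utils.data.DataLoader(
    raw_set, batch_size=global_batch_size, shuffle=False, num_workers=2
)
pixel_mean = numpy.zeros((3, 32, 32))
for batch, _ in raw_loader:
    pixel_mean += batch.numpy().sum(axis=0)
pixel_mean /= len(raw_set)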

In [ ]: