Model Playground

Various tests and small experiments on toy networks.


In [22]:
from importlib import reload  # `imp` is deprecated in Python 3
import nupic.research.frameworks.dynamic_sparse.networks.layers as layers
reload(layers);
import nupic.research.frameworks.dynamic_sparse.networks as networks
reload(networks);

In [59]:
from collections import OrderedDict

import numpy as np
import torch
from torchvision import models
from nupic.research.frameworks.dynamic_sparse.networks.layers import DSConv2d
from nupic.torch.models.sparse_cnn import gsc_sparse_cnn, gsc_super_sparse_cnn, GSCSparseCNN, MNISTSparseCNN
from nupic.research.frameworks.dynamic_sparse.networks import mnist_sparse_dsnn, GSCSparseFullCNN, gsc_sparse_dsnn_fullyconv
from torchsummary import summary

from torchviz import make_dot

Load Models


In [61]:
# resnet18 = models.resnet18()
resnet50 = models.resnet50()

alexnet = models.alexnet()
# mnist_scnn = MNISTSparseCNN()
gsc_scnn = GSCSparseCNN()
# dscnn = mnist_sparse_dsnn({})
# gscf = gsc_sparse_dsnn_fullyconv({'prune_methods': ["none", "static"]}) # GSCSparseFullCNN(cnn_out_channels=(32, 64, 1))

In [541]:
# resnet18
# mnist_scnn
gsc_scnn
# dscnn
# gscf


Out[541]:
GSCSparseCNN(
  (cnn1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (cnn1_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (cnn1_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn1_kwinner): KWinners2d(channels=64, n=0, percent_on=0.095, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (cnn2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (cnn2_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (cnn2_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2_kwinner): KWinners2d(channels=64, n=0, percent_on=0.125, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (flatten): Flatten()
  (linear): SparseWeights(
    weight_sparsity=0.4
    (module): Linear(in_features=1600, out_features=1000, bias=True)
  )
  (linear_bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (linear_kwinner): KWinners(n=1000, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (output): Linear(in_features=1000, out_features=12, bias=True)
  (softmax): LogSoftmax()
)
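
A quick weight tally from the printout above (a sketch; weights only, biases excluded). These four counts reappear as the `params` array in the "Small Dense GSC nets" section below.

In [ ]:
conv1_w = 1 * 64 * 5 * 5    # 1,600
conv2_w = 64 * 64 * 5 * 5   # 102,400
linear_w = 1600 * 1000      # 1,600,000
output_w = 1000 * 12        # 12,000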

In [412]:
inp = torch.rand(2, 1, 32, 32)
gsc_scnn(inp).shape  # discarded; only the last expression in a cell displays
gscf(inp).shape      # `gscf` from the constructor commented out above

summary(gscf, input_size=(1, 32, 32))


hi <class 'nupic.torch.models.sparse_cnn.GSCSparseCNN'>
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             832
       BatchNorm2d-2           [-1, 32, 28, 28]               0
        KWinners2d-3           [-1, 32, 28, 28]               0
         MaxPool2d-4           [-1, 32, 14, 14]               0
      SparseConv2d-5           [-1, 64, 10, 10]          51,264
       BatchNorm2d-6           [-1, 64, 10, 10]               0
        KWinners2d-7           [-1, 64, 10, 10]               0
         MaxPool2d-8             [-1, 64, 5, 5]               0
           Flatten-9                 [-1, 1600]               0
           Linear-10                   [-1, 12]          19,212
       LogSoftmax-11                   [-1, 12]               0
================================================================
Total params: 71,308
Trainable params: 71,308
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.79
Params size (MB): 0.27
Estimated Total Size (MB): 1.07
----------------------------------------------------------------
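
As a cross-check on torchsummary's totals, the parameter count can be read straight off the model (a minimal sketch; `gscf` as constructed above):

In [ ]:
total = sum(p.numel() for p in gscf.parameters())
trainable = sum(p.numel() for p in gscf.parameters() if p.requires_grad)
total, trainable  # expected to match the 71,308 reported above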

Fun with sequentials


In [336]:
od = OrderedDict([('cnn1', torch.nn.Conv2d(3, 3, 3))])
sq0 = torch.nn.Sequential(OrderedDict([('sq1', torch.nn.Sequential(od))]))
sq1 = torch.nn.Sequential(od)
sq2 = torch.nn.Sequential(torch.nn.Sequential(od), torch.nn.Conv2d(3, 3, 3))
sq3 = torch.nn.Sequential(OrderedDict([('sq1', sq1), ('sq2', sq2)]))
sq4 = torch.nn.Sequential(sq3)

In [339]:
for n, m in sq4.named_modules():
    ns = n.split('.')
    print([n_.isdigit() for n_ in ns])
    print('name')
    print(n, m)   

# for n, m in sq2._modules.items():
#     print(n, m)


[False]
name
 Sequential(
  (0): Sequential(
    (sq1): Sequential(
      (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (sq2): Sequential(
      (0): Sequential(
        (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
      )
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
  )
)
[True]
name
0 Sequential(
  (sq1): Sequential(
    (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
  (sq2): Sequential(
    (0): Sequential(
      (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
)
[True, False]
name
0.sq1 Sequential(
  (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
)
[True, False, False]
name
0.sq1.cnn1 Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
[True, False]
name
0.sq2 Sequential(
  (0): Sequential(
    (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
  (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
)
[True, False, True]
name
0.sq2.0 Sequential(
  (cnn1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
)
[True, False, True]
name
0.sq2.1 Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
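
Note that `0.sq2.0.cnn1` never appears above: `named_modules()` deduplicates shared instances, and the inner Sequential wraps the same `od` Conv2d as `0.sq1.cnn1`. Given a dotted name, the module can be fetched back using the same digit check as in the loop (a sketch):

In [ ]:
from functools import reduce

def get_module(root, dotted_name):
    # Digit parts index into a Sequential; named parts are attribute lookups.
    return reduce(
        lambda mod, part: mod[int(part)] if part.isdigit() else getattr(mod, part),
        dotted_name.split('.'),
        root,
    )

get_module(sq4, '0.sq2.1')  # the trailing Conv2d from the printout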

Fun with grads


In [114]:
v1 = torch.tensor([0., 0., 0.], requires_grad=True)
v2 = torch.tensor([1., 2., 3.], requires_grad=True)
v3 = torch.tensor([5.], requires_grad=True)
v4 = (v1.sum() + v2.sum()) / v3
h = v3.register_hook(lambda grad: grad * 1.5)  # scale the gradient by 1.5

v4.backward(torch.tensor([1.]))
v1.grad, v2.grad, v3.grad


Out[114]:
(tensor([0.2000, 0.2000, 0.2000]),
 tensor([0.2000, 0.2000, 0.2000]),
 tensor([-0.3600]))

In [118]:
v1 = torch.tensor([1., 4., 1.], requires_grad=True)
v2 = torch.tensor([1., 2., 3.], requires_grad=True)
v3 = torch.tensor([5.], requires_grad=True)
v4 = (v1.sum() + v2.sum()) / v3
h = v3.register_hook(lambda grad: grad * 3.0)  # triple the gradient

v4.backward(torch.tensor([1.]))
v1.grad, v2.grad, v3.grad


Out[118]:
(tensor([0.2000, 0.2000, 0.2000]),
 tensor([0.2000, 0.2000, 0.2000]),
 tensor([-1.4400]))
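
The numbers check out analytically: with v4 = (v1.sum() + v2.sum()) / v3, the raw gradient is d(v4)/d(v3) = -(v1.sum() + v2.sum()) / v3**2, which the hook then scales (a sketch for the cell above):

In [ ]:
s = (v1.sum() + v2.sum()).item()  # 12.0
print(-s / v3.item() ** 2 * 3.0)  # -1.44, matching v3.grad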

Wide ResNet


In [262]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False) or None
    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)
    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

In [267]:
WideResNet(16, 10)


Out[267]:
WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (block2): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (block3): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
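
A quick shape sanity check (assuming CIFAR-sized 32x32 inputs, which the hard-coded avg_pool2d(out, 8) implies: the two stride-2 blocks take 32 -> 8 before pooling):

In [ ]:
WideResNet(16, 10)(torch.rand(2, 3, 32, 32)).shape  # torch.Size([2, 10])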

Fun with Learning Rates and Decays


In [586]:
import torch
import numpy as np

np.random.seed(123)
np.set_printoptions(8, suppress=True)

x_numpy = np.random.random((3, 4)).astype(np.double)
x_torch = torch.tensor(x_numpy, requires_grad=True)
x_torch2 = torch.tensor(x_numpy, requires_grad=True)

w_numpy = np.random.random((4, 5)).astype(np.double)
w_torch = torch.tensor(w_numpy, requires_grad=True)
w_torch2 = torch.tensor(w_numpy, requires_grad=True)

def log_grad(grad):
    print(grad)
    
w_torch.register_hook(log_grad)
w_torch2.register_hook(log_grad)

lr = 0.00001
weight_decay = 0.9
sgd = torch.optim.SGD([w_torch], lr=lr, weight_decay=0)
sgd2 = torch.optim.SGD([w_torch2], lr=lr, weight_decay=weight_decay)

y_torch = torch.matmul(x_torch, w_torch)
y_torch2 = torch.matmul(x_torch2, w_torch2)

loss = y_torch.sum()
loss2 = y_torch2.sum()

sgd.zero_grad()
sgd2.zero_grad()

loss.backward()
loss2.backward()

sgd.step()
sgd2.step()

w_grad = w_torch.grad.data.numpy()
w_grad2 = w_torch2.grad.data.numpy()

print("check_grad")
print(w_grad)
print(w_grad2 - weight_decay * w_numpy)


tensor([[1.8969, 1.8969, 1.8969, 1.8969, 1.8969],
        [1.1014, 1.1014, 1.1014, 1.1014, 1.1014],
        [1.5508, 1.5508, 1.5508, 1.5508, 1.5508],
        [1.9652, 1.9652, 1.9652, 1.9652, 1.9652]], dtype=torch.float64)
tensor([[1.8969, 1.8969, 1.8969, 1.8969, 1.8969],
        [1.1014, 1.1014, 1.1014, 1.1014, 1.1014],
        [1.5508, 1.5508, 1.5508, 1.5508, 1.5508],
        [1.9652, 1.9652, 1.9652, 1.9652, 1.9652]], dtype=torch.float64)
check_grad
[[1.89687006 1.89687006 1.89687006 1.89687006 1.89687006]
 [1.10136331 1.10136331 1.10136331 1.10136331 1.10136331]
 [1.55079367 1.55079367 1.55079367 1.55079367 1.55079367]
 [1.96519422 1.96519422 1.96519422 1.96519422 1.96519422]]
[[1.89687006 1.89687006 1.89687006 1.89687006 1.89687006]
 [1.10136331 1.10136331 1.10136331 1.10136331 1.10136331]
 [1.55079367 1.55079367 1.55079367 1.55079367 1.55079367]
 [1.96519422 1.96519422 1.96519422 1.96519422 1.96519422]]
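
Equivalently, the whole decayed update can be replayed by hand (a sketch; plain-L2 behavior, which the matching printouts above confirm):

In [ ]:
w_manual = w_numpy - lr * (w_grad + weight_decay * w_numpy)
print(np.allclose(w_manual, w_torch2.detach().numpy()))  # True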

More fun with Gradients


In [1]:
import torch

# -----------------
# Helper function
# -----------------
def shape(t):
    if isinstance(t, tuple):
        return tuple(t_.shape if t_ is not None else None for t_ in t)
    else:
        return t.shape

# -----------------
# Grad hooks
# -----------------

# Zeros grad for weights
def w_hook(grad):
    print(' '*8, 'w-grad shape = ', shape(grad))
    grad[:] = 0
    return grad

# No change for biases.
def b_hook(grad):
    print(' '*8, 'b-grad shape = ', shape(grad))
    return grad

# -----------------------
# Test layers with biases
# -----------------------

# The following should confirm whether non-zero bias gradients still update the
# biases of a layer independently of the weights' gradient flow (zeroed here).

layer1 = torch.nn.Conv2d(3, 3, 3)
layer2 = torch.nn.Linear(10, 100)
in1 = torch.rand(10, 3, 10, 10)
in2 = torch.rand(10, 10, 10)

for layer, input_ in [(layer1, in1), (layer2, in2)]:
    
    print('-------', layer.__class__.__name__, '--------\n')
    layer.weight.register_hook(w_hook)
    layer.bias.register_hook(b_hook)

    optim = torch.optim.SGD(layer.parameters(), lr=0.01)
    
    # Sets all weights and biases to 1.
    with torch.no_grad():
        layer.weight.data[:] = 1
        layer.bias.data[:] = 1

    optim.zero_grad()
    o = layer(input_)
    loss = o.mean()
    
    print(' '*5, 'Computing grads...')
    loss.backward()
    optim.step()
    
    # See if weights and biases are still 1.
    # This should only be the case for the weights
    # as we zeroed their gradients.
    print()
    print(' '*5, 'Checking results...')
    print(' '*8, 'Optimized weight - All close to 1:', (layer.weight == 1).all())
    print(' '*8, 'Optimized Bias - All close to 1:', (layer.bias == 1).all())
    print()


------- Conv2d --------

      Computing grads...
         w-grad shape =  torch.Size([3, 3, 3, 3])
         b-grad shape =  torch.Size([3])

      Checking results...
         Optimized weight - All close to 1: tensor(True)
         Optimized Bias - All close to 1: tensor(False)

------- Linear --------

      Computing grads...
         b-grad shape =  torch.Size([100])
         w-grad shape =  torch.Size([100, 10])

      Checking results...
         Optimized weight - All close to 1: tensor(True)
         Optimized Bias - All close to 1: tensor(False)

Fun with Datasets and Dataloaders


In [40]:
from torchvision import datasets, transforms

In [41]:
import os

In [78]:
data_dir = "~/nta/datasets"
data_dir = os.path.expanduser(data_dir)
if os.path.exists(data_dir):
    dataset = getattr(datasets, "CIFAR10")(
                    root=data_dir, train=True,
                    transform=transforms.Compose([transforms.ToTensor()])
                )
else:
    print("Couldn't find path {}".format(data_dir))

In [79]:
tl = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=True
)
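
Pulling one batch confirms the loader wiring (a sketch; CIFAR-10 with ToTensor gives 3x32x32 images):

In [ ]:
images, labels = next(iter(tl))
images.shape, labels.shape  # (torch.Size([4, 3, 32, 32]), torch.Size([4]))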

In [80]:
class C(torch.nn.Conv2d):
    def to(self, device, *args, **kwargs):
        print(device, args)
        return super().to(device, *args, **kwargs)  # return the module, as nn.Module.to does
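
The override just logs device moves before delegating; a sketch of what it prints:

In [ ]:
c = C(3, 3, 3)
c.to('cpu')  # prints: cpu ()  -- and, with the return above, yields the module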

Fun with iterators


In [51]:
class L(list):
    def __iter__(self):
        elem = super().__iter__()
        print('elem', elem)
        return elem

In [52]:
l = L([1, 2, 3])

In [54]:
for l_ in l:
    print(l_)


elem <list_iterator object at 0x11d864438>
1
2
3

Small Dense GSC nets


In [101]:
np.set_printoptions(suppress=True)
# Dense GSC weight counts from the printout above:
# conv1 (1*64*5*5), conv2 (64*64*5*5), linear (1600*1000), output (1000*12).
params = np.array([1600, 102400, 1600000, 12000])

In [188]:
def _get_gsc_small_dense_params(on_perc, verbose=False):
    
    def vprint(*args):
        if verbose:
            print(*args)
    
    # Define number of params in dense GSC.
    # TODO: make configurable based off original `cnn_out_channels` and `linear_units`
    # default_config = dict(
    #     cnn_out_channels=(64, 64),
    #     linear_units=1000,
    # )
    large_dense_params = np.array([1600, 102400, 1600000, 12000])
    
    # Calculate num params in large sparse GSC.
    large_sparse_params = large_dense_params * on_perc
    
    # Calculate number of params in small dense (a 5x5 kernel holds 25 weights).
    kernel_adjustment_factor = np.array([25, 25, 1, 1])
    small_sparse_params = large_sparse_params / kernel_adjustment_factor
    
    # Init desired config.
    cnn_out_channels = np.array([0, 0])
    linear_units = None

    # Assume 1 channel input to first conv
    cnn_out_channels[0] = large_sparse_params[0] / 25
    cnn_out_channels[1] = large_sparse_params[1] / (cnn_out_channels[0] * 25)
    cnn_out_channels = np.round(cnn_out_channels).astype(int)
    linear_units = large_sparse_params[2] / (25 * cnn_out_channels[1])
    linear_units = int(np.round(linear_units))

    # Simulate forward pass for sanity check.
    conv1 = torch.nn.Conv2d(1, cnn_out_channels[0], 5)
    maxp1 = torch.nn.MaxPool2d(2)
    conv2 = torch.nn.Conv2d(cnn_out_channels[0], cnn_out_channels[1], 5)
    maxp2 = torch.nn.MaxPool2d(2)
    flat = torch.nn.Flatten()
    lin1 = torch.nn.Linear(25 * cnn_out_channels[1], linear_units)
    lin2 = torch.nn.Linear(linear_units, 12)
    
    x = torch.rand(10, 1, 32, 32)
    x = conv1(x)
    x = maxp1(x)
    x = conv2(x)
    x = maxp2(x)
    x = flat(x)
    x = lin1(x)
    x = lin2(x)
    
    # Calculate number of params.
    new_params = {
        "conv_1": np.prod(conv1.weight.shape),
        "conv_2": np.prod(conv2.weight.shape),
        "lin1": np.prod(lin1.weight.shape),
        "lin2": np.prod(lin2.weight.shape),
    }

    # Compare with desired.
    total_new = 0
    total_old = 0
    for p_old, (layer, p_new) in zip(large_sparse_params, new_params.items()):
        abs_diff = p_new - p_old
        rel_diff = abs_diff / float(p_old)
        vprint('---- {} -----'.format(layer))
        vprint('   new - ', p_new)
        vprint('   old - ', p_old)
        vprint('   abs diff:', abs_diff)
        vprint('   rel diff: {}% change'.format(100 * rel_diff))
        vprint()
        total_new += p_new
        total_old += p_old
    
    total_abs_diff = total_new - total_old
    total_rel_diff = total_abs_diff / float(total_old)
    vprint('---- Summary ----')
    vprint('   total new - ', total_new)
    vprint('   total old - ', total_old)
    vprint('   total abs diff:', total_abs_diff)
    vprint('   total rel diff: {}% change'.format(100 * total_rel_diff))
    
    # New config
    new_config = dict(
        cnn_out_channels=tuple(cnn_out_channels),
        linear_units=linear_units,
    )
    return new_config

for perc in [0.02, 0.04, 0.06, 0.08, 0.10]:
    c = _get_gsc_small_dense_params(perc, verbose=True)
    net = GSCSparseCNN(**c)


---- conv_1 -----
   new -  25
   old -  32.0
   abs diff: -7.0
   rel diff: -21.875% change

---- conv_2 -----
   new -  2025
   old -  2048.0
   abs diff: -23.0
   rel diff: -1.123046875% change

---- lin1 -----
   new -  32400
   old -  32000.0
   abs diff: 400.0
   rel diff: 1.25% change

---- lin2 -----
   new -  192
   old -  240.0
   abs diff: -48.0
   rel diff: -20.0% change

---- Summary ----
   total new -  34642
   total old -  34320.0
   total abs diff: 322.0
   total rel diff: 0.9382284382284383% change
---- conv_1 -----
   new -  50
   old -  64.0
   abs diff: -14.0
   rel diff: -21.875% change

---- conv_2 -----
   new -  4050
   old -  4096.0
   abs diff: -46.0
   rel diff: -1.123046875% change

---- lin1 -----
   new -  64800
   old -  64000.0
   abs diff: 800.0
   rel diff: 1.25% change

---- lin2 -----
   new -  384
   old -  480.0
   abs diff: -96.0
   rel diff: -20.0% change

---- Summary ----
   total new -  69284
   total old -  68640.0
   total abs diff: 644.0
   total rel diff: 0.9382284382284383% change
---- conv_1 -----
   new -  75
   old -  96.0
   abs diff: -21.0
   rel diff: -21.875% change

---- conv_2 -----
   new -  6075
   old -  6144.0
   abs diff: -69.0
   rel diff: -1.123046875% change

---- lin1 -----
   new -  95175
   old -  96000.0
   abs diff: -825.0
   rel diff: -0.8593750000000001% change

---- lin2 -----
   new -  564
   old -  720.0
   abs diff: -156.0
   rel diff: -21.666666666666668% change

---- Summary ----
   total new -  101889
   total old -  102960.0
   total abs diff: -1071.0
   total rel diff: -1.0402097902097902% change
---- conv_1 -----
   new -  125
   old -  128.0
   abs diff: -3.0
   rel diff: -2.34375% change

---- conv_2 -----
   new -  8125
   old -  8192.0
   abs diff: -67.0
   rel diff: -0.81787109375% change

---- lin1 -----
   new -  128375
   old -  128000.0
   abs diff: 375.0
   rel diff: 0.29296875% change

---- lin2 -----
   new -  948
   old -  960.0
   abs diff: -12.0
   rel diff: -1.25% change

---- Summary ----
   total new -  137573
   total old -  137280.0
   total abs diff: 293.0
   total rel diff: 0.21343240093240093% change
---- conv_1 -----
   new -  150
   old -  160.0
   abs diff: -10.0
   rel diff: -6.25% change

---- conv_2 -----
   new -  10200
   old -  10240.0
   abs diff: -40.0
   rel diff: -0.390625% change

---- lin1 -----
   new -  159800
   old -  160000.0
   abs diff: -200.0
   rel diff: -0.125% change

---- lin2 -----
   new -  1128
   old -  1200.0
   abs diff: -72.0
   rel diff: -6.0% change

---- Summary ----
   total new -  171278
   total old -  171600.0
   total abs diff: -322.0
   total rel diff: -0.18764568764568765% change

In [184]:
net


Out[184]:
GSCSparseCNN(
  (cnn1): Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1))
  (cnn1_batchnorm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (cnn1_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn1_kwinner): KWinners2d(channels=1, n=0, percent_on=0.095, boost_strength=1.6699999570846558, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (cnn2): Conv2d(1, 81, kernel_size=(5, 5), stride=(1, 1))
  (cnn2_batchnorm): BatchNorm2d(81, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (cnn2_maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2_kwinner): KWinners2d(channels=81, n=0, percent_on=0.125, boost_strength=1.6699999570846558, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (flatten): Flatten()
  (linear): SparseWeights(
    weight_sparsity=0.4
    (module): Linear(in_features=2025, out_features=16, bias=True)
  )
  (linear_bn): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (linear_kwinner): KWinners(n=16, percent_on=0.1, boost_strength=1.6699999570846558, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)
  (output): Linear(in_features=16, out_features=12, bias=True)
  (softmax): LogSoftmax()
)

In [117]:
for perc in [0.02, 0.04, 0.06, 0.08, 0.10]:
    old_params = params * perc
    display(old_params) 
    factor = np.array([25, 25, 1, 1])
    new_params = old_params / factor
    new_params = np.round(new_params).astype(int) * factor
    display(new_params)
    
    diff = new_params - (old_params)
    print(diff)
    print(np.sum(diff))
    print(np.sum(diff) / np.sum(old_params))

    
    print()


array([   32.,  2048., 32000.,   240.])
array([   25,  2050, 32000,   240])
[-7.  2.  0.  0.]
-5.0
-0.0001456876456876457

array([   64.,  4096., 64000.,   480.])
array([   75,  4100, 64000,   480])
[11.  4.  0.  0.]
15.0
0.00021853146853146853

array([   96.,  6144., 96000.,   720.])
array([  100,  6150, 96000,   720])
[4. 6. 0. 0.]
10.0
9.712509712509713e-05

array([   128.,   8192., 128000.,    960.])
array([   125,   8200, 128000,    960])
[-3.  8.  0.  0.]
5.0
3.642191142191142e-05

array([   160.,  10240., 160000.,   1200.])
array([   150,  10250, 160000,   1200])
[-10.  10.   0.   0.]
0.0
0.0

Fun with Resnets


In [62]:
resnet50


Out[62]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

In [65]:
# Source: NVIDIA DeepLearningExamples
# PyTorch/Classification/RN50v1.5/image_classification/resnet.py
# (@nvpstr, commit 5eaebef on May 27; 272 lines, 8.15 KB)
  
import math
import torch
import torch.nn as nn
import numpy as np

__all__ = ['ResNet', 'build_resnet', 'resnet_versions', 'resnet_configs']

# ResNetBuilder {{{

class ResNetBuilder(object):
    def __init__(self, version, config):
        self.config = config

        self.L = sum(version['layers'])
        self.M = version['block'].M

    def conv(self, kernel_size, in_planes, out_planes, stride=1):
        if kernel_size == 3:
            conv = self.config['conv'](
                    in_planes, out_planes, kernel_size=3, stride=stride,
                    padding=1, bias=False)
        elif kernel_size == 1:
            conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                             bias=False)
        elif kernel_size == 5:
            conv = nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                             padding=2, bias=False)
        elif kernel_size == 7:
            conv = nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride,
                             padding=3, bias=False)
        else:
            return None

        if self.config['nonlinearity'] == 'relu':
            nn.init.kaiming_normal_(conv.weight,
                    mode=self.config['conv_init'],
                    nonlinearity=self.config['nonlinearity'])

        return conv

    def conv3x3(self, in_planes, out_planes, stride=1):
        """3x3 convolution with padding"""
        c = self.conv(3, in_planes, out_planes, stride=stride)
        return c

    def conv1x1(self, in_planes, out_planes, stride=1):
        """1x1 convolution with padding"""
        c = self.conv(1, in_planes, out_planes, stride=stride)
        return c

    def conv7x7(self, in_planes, out_planes, stride=1):
        """7x7 convolution with padding"""
        c = self.conv(7, in_planes, out_planes, stride=stride)
        return c

    def conv5x5(self, in_planes, out_planes, stride=1):
        """5x5 convolution with padding"""
        c = self.conv(5, in_planes, out_planes, stride=stride)
        return c

    def batchnorm(self, planes, last_bn=False):
        bn = nn.BatchNorm2d(planes)
        gamma_init_val = 0 if last_bn and self.config['last_bn_0_init'] else 1
        nn.init.constant_(bn.weight, gamma_init_val)
        nn.init.constant_(bn.bias, 0)

        return bn

    def activation(self):
        return self.config['activation']()

# ResNetBuilder }}}

# BasicBlock {{{
class BasicBlock(nn.Module):
    M = 2
    expansion = 1

    def __init__(self, builder, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = builder.conv3x3(inplanes, planes, stride)
        self.bn1 = builder.batchnorm(planes)
        self.relu = builder.activation()
        self.conv2 = builder.conv3x3(planes, planes)
        self.bn2 = builder.batchnorm(planes, last_bn=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        if self.bn1 is not None:
            out = self.bn1(out)

        out = self.relu(out)

        out = self.conv2(out)

        if self.bn2 is not None:
            out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
# BasicBlock }}}

# Bottleneck {{{
class Bottleneck(nn.Module):
    M = 3
    expansion = 4

    def __init__(self, builder, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = builder.conv1x1(inplanes, planes)
        self.bn1 = builder.batchnorm(planes)
        self.conv2 = builder.conv3x3(planes, planes, stride=stride)
        self.bn2 = builder.batchnorm(planes)
        self.conv3 = builder.conv1x1(planes, planes * self.expansion)
        self.bn3 = builder.batchnorm(planes * self.expansion, last_bn=True)
        self.relu = builder.activation()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        out = self.relu(out)

        return out
# Bottleneck }}}

# ResNet {{{
class ResNet(nn.Module):
    def __init__(self, builder, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = builder.conv7x7(3, 64, stride=2)
        self.bn1 = builder.batchnorm(64)
        self.relu = builder.activation()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(builder, block, 64, layers[0])
        self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, builder, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            dconv = builder.conv1x1(self.inplanes, planes * block.expansion,
                                    stride=stride)
            dbn = builder.batchnorm(planes * block.expansion)
            if dbn is not None:
                downsample = nn.Sequential(dconv, dbn)
            else:
                downsample = dconv

        layers = []
        layers.append(block(builder, self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(builder, self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if self.bn1 is not None:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
# ResNet }}}


resnet_configs = {
        'classic' : {
            'conv' : nn.Conv2d,
            'conv_init' : 'fan_out',
            'nonlinearity' : 'relu',
            'last_bn_0_init' : False,
            'activation' : lambda: nn.ReLU(inplace=True),
            },
        'fanin' : {
            'conv' : nn.Conv2d,
            'conv_init' : 'fan_in',
            'nonlinearity' : 'relu',
            'last_bn_0_init' : False,
            'activation' : lambda: nn.ReLU(inplace=True),
            },
        }

resnet_versions = {
        'resnet18' : {
            'net' : ResNet,
            'block' : BasicBlock,
            'layers' : [2, 2, 2, 2],
            'num_classes' : 1000,
            },
         'resnet34' : {
            'net' : ResNet,
            'block' : BasicBlock,
            'layers' : [3, 4, 6, 3],
            'num_classes' : 1000,
            },
         'resnet50' : {
            'net' : ResNet,
            'block' : Bottleneck,
            'layers' : [3, 4, 6, 3],
            'num_classes' : 1000,
            },
        'resnet101' : {
            'net' : ResNet,
            'block' : Bottleneck,
            'layers' : [3, 4, 23, 3],
            'num_classes' : 1000,
            },
        'resnet152' : {
            'net' : ResNet,
            'block' : Bottleneck,
            'layers' : [3, 8, 36, 3],
            'num_classes' : 1000,
            },
        }


def build_resnet(version, config, model_state=None):
    version = resnet_versions[version]
    config = resnet_configs[config]

    builder = ResNetBuilder(version, config)
    print("Version: {}".format(version))
    print("Config: {}".format(config))
    model = version['net'](builder,
                           version['block'],
                           version['layers'],
                           version['num_classes'])

    return model
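
Usage (a sketch; `build_resnet` prints the version and config dicts before constructing):

In [ ]:
model = build_resnet('resnet50', 'classic')
model(torch.rand(2, 3, 224, 224)).shape  # torch.Size([2, 1000])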

In [79]:
i = torch.rand(10, 3, 224, 224)
make_dot(resnet50(i).mean()) #, params=dict(resnet50.named_parameters()))


Out[79]:
[torchviz graph of the resnet50 backward pass: long chains of MkldnnConvolutionBackward -> NativeBatchNormBackward -> ReluBackward1 joined by AddBackward0 residual nodes, ending in MeanBackward0 / AddmmBackward; raw Graphviz dump omitted]
(256, 256, 3, 3) 4900956592->5741835488 5036560568 (256) 5036560568->5741835376 5036560680 (256) 5036560680->5741835376 4875345032 (1024, 256, 1, 1) 4875345032->5741835152 5547555584 (1024) 5547555584->5741834984 5547555696 (1024) 5547555696->5741834984 5660022936 (256, 1024, 1, 1) 5660022936->5741834704 5660021592 (256) 5660021592->5741834592 5660021200 (256) 5660021200->5741834592 5660022152 (256, 256, 3, 3) 5660022152->5741834368 5660022544 (256) 5660022544->5741768600 5660021088 (256) 5660021088->5741768600 5660022320 (1024, 256, 1, 1) 5660022320->5741768432 5660024336 (1024) 5660024336->5741767424 5660023160 (1024) 5660023160->5741767424 5660023552 (256, 1024, 1, 1) 5660023552->5741765632 5660024224 (256) 5660024224->5741765912 5660023944 (256) 5660023944->5741765912 5660024392 (256, 256, 3, 3) 5660024392->5741764960 5741735440 (256) 5741735440->5741756088 5741735496 (256) 5741735496->5741756088 5741735104 (1024, 256, 1, 1) 5741735104->5741755864 5741734880 (1024) 5741734880->5741755304 5741734936 (1024) 5741734936->5741755304 5741734376 (256, 1024, 1, 1) 5741734376->5741754520 5741734208 (256) 5741734208->5741754016 5741734264 (256) 5741734264->5741754016 5741733872 (256, 256, 3, 3) 5741733872->5741751936 5741733648 (256) 5741733648->5741752160 5741733704 (256) 5741733704->5741752160 5741733312 (1024, 256, 1, 1) 5741733312->5741751320 5741733088 (1024) 5741733088->5741750928 5741733144 (1024) 5741733144->5741750928 5741732584 (512, 1024, 1, 1) 5741732584->5741750032 5741732360 (512) 5741732360->5741749136 5741732416 (512) 5741732416->5741749136 5741732024 (512, 512, 3, 3) 5741732024->5741749024 5741731736 (512) 5741731736->5741748576 5741731792 (512) 5741731792->5741748576 5741731400 (2048, 512, 1, 1) 5741731400->5741743352 5741731176 (2048) 5741731176->5741743520 5741731232 (2048) 5741731232->5741743520 5741743240 NativeBatchNormBackward 5741743240->5741743072 5741743912->5741743240 5741731568 (2048, 1024, 1, 1) 5741731568->5741743912 5638186152 (2048) 5638186152->5741743240 5741731288 (2048) 5741731288->5741743240 5741730672 (512, 2048, 1, 1) 5741730672->5741741840 5741730504 (512) 5741730504->5741741560 5741730560 (512) 5741730560->5741741560 5741730168 (512, 512, 3, 3) 5741730168->5741740944 5741729944 (512) 5741729944->5741740440 5741730000 (512) 5741730000->5741740440 5741729608 (2048, 512, 1, 1) 5741729608->5741735664 5741729384 (2048) 5741729384->5741735160 5741729440 (2048) 5741729440->5741735160 5741728936 (512, 2048, 1, 1) 5741728936->5741734320 5741728768 (512) 5741728768->5741733928 5741728824 (512) 5741728824->5741733928 5741728488 (512, 512, 3, 3) 5741728488->5741733200 5741728264 (512) 5741728264->5741732752 5741728320 (512) 5741728320->5741732752 5741727984 (2048, 512, 1, 1) 5741727984->5741732080 5660123088 (2048) 5660123088->4900957208 5741727816 (2048) 5741727816->4900957208 5660022880 TBackward 5660022880->5660024784 5660122360 (1000, 2048) 5660122360->5660022880
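
A minimal sketch of rendering the graph to an image instead of printing the raw Graphviz source (assumes the resnet50 model and the make_dot import from above):

In [ ]:
# Build the autograd graph and write it out as a PNG instead of
# dumping the Graphviz source inline.
inp = torch.rand(2, 3, 224, 224)
dot = make_dot(resnet50(inp), params=dict(resnet50.named_parameters()))
dot.render("resnet50_graph", format="png")  # writes resnet50_graph.png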

In [97]:
resnet50 = build_resnet('resnet50', 'classic')


Version: {'net': <class '__main__.ResNet'>, 'block': <class '__main__.Bottleneck'>, 'layers': [3, 4, 6, 3], 'num_classes': 1000}
Config: {'conv': <class 'torch.nn.modules.conv.Conv2d'>, 'conv_init': 'fan_out', 'nonlinearity': 'relu', 'last_bn_0_init': False, 'activation': <function <lambda> at 0x122955840>}
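
The actual build_resnet is defined earlier in the notebook; as a rough sketch, the conv_init and nonlinearity fields of the config presumably drive a Kaiming init pass along these lines (init_convs is a hypothetical helper, not the notebook's own):

In [ ]:
import torch.nn as nn

def init_convs(model, conv_init="fan_out", nonlinearity="relu"):
    # Hypothetical sketch: apply Kaiming-normal init to every conv,
    # using the fan mode named in the config ('fan_out' for 'classic',
    # 'fan_in' for 'fanin').
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode=conv_init, nonlinearity=nonlinearity)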

In [96]:
summary(resnet50, (3, 224, 224));


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
       Bottleneck-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
      BatchNorm2d-24          [-1, 256, 56, 56]             512
             ReLU-25          [-1, 256, 56, 56]               0
       Bottleneck-26          [-1, 256, 56, 56]               0
           Conv2d-27           [-1, 64, 56, 56]          16,384
      BatchNorm2d-28           [-1, 64, 56, 56]             128
             ReLU-29           [-1, 64, 56, 56]               0
           Conv2d-30           [-1, 64, 56, 56]          36,864
      BatchNorm2d-31           [-1, 64, 56, 56]             128
             ReLU-32           [-1, 64, 56, 56]               0
           Conv2d-33          [-1, 256, 56, 56]          16,384
      BatchNorm2d-34          [-1, 256, 56, 56]             512
             ReLU-35          [-1, 256, 56, 56]               0
       Bottleneck-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          32,768
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
       Bottleneck-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
       Bottleneck-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
       Bottleneck-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
       Bottleneck-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 28, 28]         131,072
      BatchNorm2d-80          [-1, 256, 28, 28]             512
             ReLU-81          [-1, 256, 28, 28]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
       Bottleneck-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
      Bottleneck-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
      Bottleneck-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
      Bottleneck-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
      Bottleneck-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 512, 14, 14]         524,288
     BatchNorm2d-142          [-1, 512, 14, 14]           1,024
            ReLU-143          [-1, 512, 14, 14]               0
          Conv2d-144            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-145            [-1, 512, 7, 7]           1,024
            ReLU-146            [-1, 512, 7, 7]               0
          Conv2d-147           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-148           [-1, 2048, 7, 7]           4,096
          Conv2d-149           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-150           [-1, 2048, 7, 7]           4,096
            ReLU-151           [-1, 2048, 7, 7]               0
      Bottleneck-152           [-1, 2048, 7, 7]               0
          Conv2d-153            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-154            [-1, 512, 7, 7]           1,024
            ReLU-155            [-1, 512, 7, 7]               0
          Conv2d-156            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-157            [-1, 512, 7, 7]           1,024
            ReLU-158            [-1, 512, 7, 7]               0
          Conv2d-159           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-160           [-1, 2048, 7, 7]           4,096
            ReLU-161           [-1, 2048, 7, 7]               0
      Bottleneck-162           [-1, 2048, 7, 7]               0
          Conv2d-163            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-164            [-1, 512, 7, 7]           1,024
            ReLU-165            [-1, 512, 7, 7]               0
          Conv2d-166            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-167            [-1, 512, 7, 7]           1,024
            ReLU-168            [-1, 512, 7, 7]               0
          Conv2d-169           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-170           [-1, 2048, 7, 7]           4,096
            ReLU-171           [-1, 2048, 7, 7]               0
      Bottleneck-172           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                 [-1, 1000]       2,049,000
================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.56
Params size (MB): 97.49
Estimated Total Size (MB): 384.62
----------------------------------------------------------------
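
A few of the table's entries can be checked by hand: torchvision's ResNet convs carry no bias, so each conv contributes out_channels * in_channels * kH * kW parameters, and the size estimates are just 4-byte floats:

In [ ]:
# Sanity-check the summary arithmetic by hand.
print(64 * 3 * 7 * 7)               # Conv2d-1:   9408 (7x7 stem, no bias)
print(64 * 64 * 1 * 1)              # Conv2d-5:   4096
print(1000 * 2048 + 1000)           # Linear-174: 2049000 (weights + bias)
print(3 * 224 * 224 * 4 / 1024**2)  # Input size (MB):  ~0.57
print(25557032 * 4 / 1024**2)       # Params size (MB): ~97.49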

In [95]:
build_resnet('resnet50', 'fanin');


Version: {'net': <class '__main__.ResNet'>, 'block': <class '__main__.Bottleneck'>, 'layers': [3, 4, 6, 3], 'num_classes': 1000}
Config: {'conv': <class 'torch.nn.modules.conv.Conv2d'>, 'conv_init': 'fan_in', 'nonlinearity': 'relu', 'last_bn_0_init': False, 'activation': <function <lambda> at 0x12298c7b8>}
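
The only difference between the 'classic' and 'fanin' builds is the Kaiming fan mode, which changes the scale of the initial weights. A small sketch using torch.nn.init directly, on the stem conv's weight shape:

In [ ]:
import torch
import torch.nn as nn

w = torch.empty(64, 3, 7, 7)  # stem conv weight shape
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
print(w.std())   # ≈ sqrt(2 / (3*7*7))  ≈ 0.117
nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")
print(w.std())   # ≈ sqrt(2 / (64*7*7)) ≈ 0.025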

In [ ]: