In [2]:
from nupic.research.frameworks.dynamic_sparse.networks import *
from nupic.research.frameworks.dynamic_sparse.models import *

In [8]:
config = dict(
    device="cpu",
    network="GSCHeb",
    optim_alg="SGD",
    momentum=0,  # was 0.9
    learning_rate=0.01,  # was 0.1
    weight_decay=0.01,  # was 1e-4
    lr_scheduler="MultiStepLR",
    lr_milestones=[30, 60, 90],
    lr_gamma=0.9,  # was 0.1
    use_kwinners=True,
    model="DSNNWeightedMag",
    on_perc=0.04,  # fraction of weights kept active (4% density)
    weight_prune_perc=0.3,  # magnitude-pruning fraction per sparse module
)

In [18]:
network = GSCHeb(config)
model = DSNNWeightedMag(network, config)
model.setup()
model.sparse_modules


Out[18]:
[{'name': 'Conv2d', 'index': 2, 'shape': torch.Size([64, 1, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64.0},
 {'name': 'Conv2d', 'index': 6, 'shape': torch.Size([64, 64, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 4096.0},
 {'name': 'DSLinear', 'index': 13, 'shape': torch.Size([1000, 1600]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64000.0},
 {'name': 'DSLinear', 'index': 17, 'shape': torch.Size([12, 1000]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 480.0}]
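
The `num_params` column is just `on_perc` times the dense weight count of each layer, which is easy to verify from the shapes reported above (a quick sanity check, nothing framework-specific assumed):

In [ ]:
import torch

# num_params reported per sparse module should equal on_perc * dense weights
on_perc = 0.04
shapes = [
    torch.Size([64, 1, 5, 5]),    # Conv2d, index 2    -> 1600 * 0.04 = 64
    torch.Size([64, 64, 5, 5]),   # Conv2d, index 6    -> 102400 * 0.04 = 4096
    torch.Size([1000, 1600]),     # DSLinear, index 13 -> 1600000 * 0.04 = 64000
    torch.Size([12, 1000]),       # DSLinear, index 17 -> 12000 * 0.04 = 480
]
for shape in shapes:
    print(shape, shape.numel() * on_perc)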

In [21]:
modules = model.sparse_modules
modules[0].get_coactivations()


Out[21]:
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]])

In [15]:
network = GSCHeb_v0(config)
model = DSNNWeightedMag(network, config)
model.setup()
model.sparse_modules


Out[15]:
[{'name': 'Conv2d', 'index': 2, 'shape': torch.Size([64, 1, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64.0},
 {'name': 'Conv2d', 'index': 6, 'shape': torch.Size([64, 64, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 4096.0},
 {'name': 'Linear', 'index': 12, 'shape': torch.Size([1000, 1600]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64000.0},
 {'name': 'Linear', 'index': 15, 'shape': torch.Size([12, 1000]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 480.0}]

In [17]:
modules = model.sparse_modules
modules[0].m.coactivations


Out[17]:
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]])
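
Two differences between the network versions are visible above: `GSCHeb_v0` registers plain `Linear` layers at indices 12 and 15 where `GSCHeb` uses `DSLinear` at 13 and 17, and the coactivations are reached differently, via `get_coactivations()` on the sparse module wrapper for `GSCHeb` versus the `modules[0].m.coactivations` attribute on the wrapped layer for `GSCHeb_v0`. In both cases the tensor is all zeros here because no forward pass has run yet; coactivations only accumulate during training. A version-agnostic accessor (a sketch, assuming only the two access paths shown above):

In [ ]:
def coactivations_of(sparse_module):
    """Fetch the coactivation tensor from either network version."""
    if hasattr(sparse_module, "get_coactivations"):
        return sparse_module.get_coactivations()  # GSCHeb-style wrapper
    return sparse_module.m.coactivations          # GSCHeb_v0-style wrapper

coactivations_of(modules[0]).abs().sum()  # tensor(0.) before any training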

In [22]:
net = resnet152()

In [23]:
from torchsummary import summary

In [24]:
summary(net, input_size=(3, 32, 32))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]             272
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,320
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]           1,088
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]           1,088
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64, 32, 32]               0
       Bottleneck-15           [-1, 64, 32, 32]               0
           Conv2d-16           [-1, 16, 32, 32]           1,040
      BatchNorm2d-17           [-1, 16, 32, 32]              32
             ReLU-18           [-1, 16, 32, 32]               0
           Conv2d-19           [-1, 16, 32, 32]           2,320
      BatchNorm2d-20           [-1, 16, 32, 32]              32
             ReLU-21           [-1, 16, 32, 32]               0
           Conv2d-22           [-1, 64, 32, 32]           1,088
      BatchNorm2d-23           [-1, 64, 32, 32]             128
             ReLU-24           [-1, 64, 32, 32]               0
       Bottleneck-25           [-1, 64, 32, 32]               0
           Conv2d-26           [-1, 16, 32, 32]           1,040
      BatchNorm2d-27           [-1, 16, 32, 32]              32
             ReLU-28           [-1, 16, 32, 32]               0
           Conv2d-29           [-1, 16, 32, 32]           2,320
      BatchNorm2d-30           [-1, 16, 32, 32]              32
             ReLU-31           [-1, 16, 32, 32]               0
           Conv2d-32           [-1, 64, 32, 32]           1,088
      BatchNorm2d-33           [-1, 64, 32, 32]             128
             ReLU-34           [-1, 64, 32, 32]               0
       Bottleneck-35           [-1, 64, 32, 32]               0
           Conv2d-36           [-1, 32, 32, 32]           2,080
      BatchNorm2d-37           [-1, 32, 32, 32]              64
             ReLU-38           [-1, 32, 32, 32]               0
           Conv2d-39           [-1, 32, 16, 16]           9,248
      BatchNorm2d-40           [-1, 32, 16, 16]              64
             ReLU-41           [-1, 32, 16, 16]               0
           Conv2d-42          [-1, 128, 16, 16]           4,224
      BatchNorm2d-43          [-1, 128, 16, 16]             256
           Conv2d-44          [-1, 128, 16, 16]           8,320
      BatchNorm2d-45          [-1, 128, 16, 16]             256
             ReLU-46          [-1, 128, 16, 16]               0
       Bottleneck-47          [-1, 128, 16, 16]               0
           Conv2d-48           [-1, 32, 16, 16]           4,128
      BatchNorm2d-49           [-1, 32, 16, 16]              64
             ReLU-50           [-1, 32, 16, 16]               0
           Conv2d-51           [-1, 32, 16, 16]           9,248
      BatchNorm2d-52           [-1, 32, 16, 16]              64
             ReLU-53           [-1, 32, 16, 16]               0
           Conv2d-54          [-1, 128, 16, 16]           4,224
      BatchNorm2d-55          [-1, 128, 16, 16]             256
             ReLU-56          [-1, 128, 16, 16]               0
       Bottleneck-57          [-1, 128, 16, 16]               0
           Conv2d-58           [-1, 32, 16, 16]           4,128
      BatchNorm2d-59           [-1, 32, 16, 16]              64
             ReLU-60           [-1, 32, 16, 16]               0
           Conv2d-61           [-1, 32, 16, 16]           9,248
      BatchNorm2d-62           [-1, 32, 16, 16]              64
             ReLU-63           [-1, 32, 16, 16]               0
           Conv2d-64          [-1, 128, 16, 16]           4,224
      BatchNorm2d-65          [-1, 128, 16, 16]             256
             ReLU-66          [-1, 128, 16, 16]               0
       Bottleneck-67          [-1, 128, 16, 16]               0
           Conv2d-68           [-1, 32, 16, 16]           4,128
      BatchNorm2d-69           [-1, 32, 16, 16]              64
             ReLU-70           [-1, 32, 16, 16]               0
           Conv2d-71           [-1, 32, 16, 16]           9,248
      BatchNorm2d-72           [-1, 32, 16, 16]              64
             ReLU-73           [-1, 32, 16, 16]               0
           Conv2d-74          [-1, 128, 16, 16]           4,224
      BatchNorm2d-75          [-1, 128, 16, 16]             256
             ReLU-76          [-1, 128, 16, 16]               0
       Bottleneck-77          [-1, 128, 16, 16]               0
           Conv2d-78           [-1, 32, 16, 16]           4,128
      BatchNorm2d-79           [-1, 32, 16, 16]              64
             ReLU-80           [-1, 32, 16, 16]               0
           Conv2d-81           [-1, 32, 16, 16]           9,248
      BatchNorm2d-82           [-1, 32, 16, 16]              64
             ReLU-83           [-1, 32, 16, 16]               0
           Conv2d-84          [-1, 128, 16, 16]           4,224
      BatchNorm2d-85          [-1, 128, 16, 16]             256
             ReLU-86          [-1, 128, 16, 16]               0
       Bottleneck-87          [-1, 128, 16, 16]               0
           Conv2d-88           [-1, 32, 16, 16]           4,128
      BatchNorm2d-89           [-1, 32, 16, 16]              64
             ReLU-90           [-1, 32, 16, 16]               0
           Conv2d-91           [-1, 32, 16, 16]           9,248
      BatchNorm2d-92           [-1, 32, 16, 16]              64
             ReLU-93           [-1, 32, 16, 16]               0
           Conv2d-94          [-1, 128, 16, 16]           4,224
      BatchNorm2d-95          [-1, 128, 16, 16]             256
             ReLU-96          [-1, 128, 16, 16]               0
       Bottleneck-97          [-1, 128, 16, 16]               0
           Conv2d-98           [-1, 32, 16, 16]           4,128
      BatchNorm2d-99           [-1, 32, 16, 16]              64
            ReLU-100           [-1, 32, 16, 16]               0
          Conv2d-101           [-1, 32, 16, 16]           9,248
     BatchNorm2d-102           [-1, 32, 16, 16]              64
            ReLU-103           [-1, 32, 16, 16]               0
          Conv2d-104          [-1, 128, 16, 16]           4,224
     BatchNorm2d-105          [-1, 128, 16, 16]             256
            ReLU-106          [-1, 128, 16, 16]               0
      Bottleneck-107          [-1, 128, 16, 16]               0
          Conv2d-108           [-1, 32, 16, 16]           4,128
     BatchNorm2d-109           [-1, 32, 16, 16]              64
            ReLU-110           [-1, 32, 16, 16]               0
          Conv2d-111           [-1, 32, 16, 16]           9,248
     BatchNorm2d-112           [-1, 32, 16, 16]              64
            ReLU-113           [-1, 32, 16, 16]               0
          Conv2d-114          [-1, 128, 16, 16]           4,224
     BatchNorm2d-115          [-1, 128, 16, 16]             256
            ReLU-116          [-1, 128, 16, 16]               0
      Bottleneck-117          [-1, 128, 16, 16]               0
          Conv2d-118           [-1, 64, 16, 16]           8,256
     BatchNorm2d-119           [-1, 64, 16, 16]             128
            ReLU-120           [-1, 64, 16, 16]               0
          Conv2d-121             [-1, 64, 8, 8]          36,928
     BatchNorm2d-122             [-1, 64, 8, 8]             128
            ReLU-123             [-1, 64, 8, 8]               0
          Conv2d-124            [-1, 256, 8, 8]          16,640
     BatchNorm2d-125            [-1, 256, 8, 8]             512
          Conv2d-126            [-1, 256, 8, 8]          33,024
     BatchNorm2d-127            [-1, 256, 8, 8]             512
            ReLU-128            [-1, 256, 8, 8]               0
      Bottleneck-129            [-1, 256, 8, 8]               0
          Conv2d-130             [-1, 64, 8, 8]          16,448
     BatchNorm2d-131             [-1, 64, 8, 8]             128
            ReLU-132             [-1, 64, 8, 8]               0
          Conv2d-133             [-1, 64, 8, 8]          36,928
     BatchNorm2d-134             [-1, 64, 8, 8]             128
            ReLU-135             [-1, 64, 8, 8]               0
          Conv2d-136            [-1, 256, 8, 8]          16,640
     BatchNorm2d-137            [-1, 256, 8, 8]             512
            ReLU-138            [-1, 256, 8, 8]               0
      Bottleneck-139            [-1, 256, 8, 8]               0
          Conv2d-140             [-1, 64, 8, 8]          16,448
     BatchNorm2d-141             [-1, 64, 8, 8]             128
            ReLU-142             [-1, 64, 8, 8]               0
          Conv2d-143             [-1, 64, 8, 8]          36,928
     BatchNorm2d-144             [-1, 64, 8, 8]             128
            ReLU-145             [-1, 64, 8, 8]               0
          Conv2d-146            [-1, 256, 8, 8]          16,640
     BatchNorm2d-147            [-1, 256, 8, 8]             512
            ReLU-148            [-1, 256, 8, 8]               0
      Bottleneck-149            [-1, 256, 8, 8]               0
          Conv2d-150             [-1, 64, 8, 8]          16,448
     BatchNorm2d-151             [-1, 64, 8, 8]             128
            ReLU-152             [-1, 64, 8, 8]               0
          Conv2d-153             [-1, 64, 8, 8]          36,928
     BatchNorm2d-154             [-1, 64, 8, 8]             128
            ReLU-155             [-1, 64, 8, 8]               0
          Conv2d-156            [-1, 256, 8, 8]          16,640
     BatchNorm2d-157            [-1, 256, 8, 8]             512
            ReLU-158            [-1, 256, 8, 8]               0
      Bottleneck-159            [-1, 256, 8, 8]               0
          Conv2d-160             [-1, 64, 8, 8]          16,448
     BatchNorm2d-161             [-1, 64, 8, 8]             128
            ReLU-162             [-1, 64, 8, 8]               0
          Conv2d-163             [-1, 64, 8, 8]          36,928
     BatchNorm2d-164             [-1, 64, 8, 8]             128
            ReLU-165             [-1, 64, 8, 8]               0
          Conv2d-166            [-1, 256, 8, 8]          16,640
     BatchNorm2d-167            [-1, 256, 8, 8]             512
            ReLU-168            [-1, 256, 8, 8]               0
      Bottleneck-169            [-1, 256, 8, 8]               0
          Conv2d-170             [-1, 64, 8, 8]          16,448
     BatchNorm2d-171             [-1, 64, 8, 8]             128
            ReLU-172             [-1, 64, 8, 8]               0
          Conv2d-173             [-1, 64, 8, 8]          36,928
     BatchNorm2d-174             [-1, 64, 8, 8]             128
            ReLU-175             [-1, 64, 8, 8]               0
          Conv2d-176            [-1, 256, 8, 8]          16,640
     BatchNorm2d-177            [-1, 256, 8, 8]             512
            ReLU-178            [-1, 256, 8, 8]               0
      Bottleneck-179            [-1, 256, 8, 8]               0
          Conv2d-180             [-1, 64, 8, 8]          16,448
     BatchNorm2d-181             [-1, 64, 8, 8]             128
            ReLU-182             [-1, 64, 8, 8]               0
          Conv2d-183             [-1, 64, 8, 8]          36,928
     BatchNorm2d-184             [-1, 64, 8, 8]             128
            ReLU-185             [-1, 64, 8, 8]               0
          Conv2d-186            [-1, 256, 8, 8]          16,640
     BatchNorm2d-187            [-1, 256, 8, 8]             512
            ReLU-188            [-1, 256, 8, 8]               0
      Bottleneck-189            [-1, 256, 8, 8]               0
          Conv2d-190             [-1, 64, 8, 8]          16,448
     BatchNorm2d-191             [-1, 64, 8, 8]             128
            ReLU-192             [-1, 64, 8, 8]               0
          Conv2d-193             [-1, 64, 8, 8]          36,928
     BatchNorm2d-194             [-1, 64, 8, 8]             128
            ReLU-195             [-1, 64, 8, 8]               0
          Conv2d-196            [-1, 256, 8, 8]          16,640
     BatchNorm2d-197            [-1, 256, 8, 8]             512
            ReLU-198            [-1, 256, 8, 8]               0
      Bottleneck-199            [-1, 256, 8, 8]               0
          Conv2d-200             [-1, 64, 8, 8]          16,448
     BatchNorm2d-201             [-1, 64, 8, 8]             128
            ReLU-202             [-1, 64, 8, 8]               0
          Conv2d-203             [-1, 64, 8, 8]          36,928
     BatchNorm2d-204             [-1, 64, 8, 8]             128
            ReLU-205             [-1, 64, 8, 8]               0
          Conv2d-206            [-1, 256, 8, 8]          16,640
     BatchNorm2d-207            [-1, 256, 8, 8]             512
            ReLU-208            [-1, 256, 8, 8]               0
      Bottleneck-209            [-1, 256, 8, 8]               0
          Conv2d-210             [-1, 64, 8, 8]          16,448
     BatchNorm2d-211             [-1, 64, 8, 8]             128
            ReLU-212             [-1, 64, 8, 8]               0
          Conv2d-213             [-1, 64, 8, 8]          36,928
     BatchNorm2d-214             [-1, 64, 8, 8]             128
            ReLU-215             [-1, 64, 8, 8]               0
          Conv2d-216            [-1, 256, 8, 8]          16,640
     BatchNorm2d-217            [-1, 256, 8, 8]             512
            ReLU-218            [-1, 256, 8, 8]               0
      Bottleneck-219            [-1, 256, 8, 8]               0
          Conv2d-220             [-1, 64, 8, 8]          16,448
     BatchNorm2d-221             [-1, 64, 8, 8]             128
            ReLU-222             [-1, 64, 8, 8]               0
          Conv2d-223             [-1, 64, 8, 8]          36,928
     BatchNorm2d-224             [-1, 64, 8, 8]             128
            ReLU-225             [-1, 64, 8, 8]               0
          Conv2d-226            [-1, 256, 8, 8]          16,640
     BatchNorm2d-227            [-1, 256, 8, 8]             512
            ReLU-228            [-1, 256, 8, 8]               0
      Bottleneck-229            [-1, 256, 8, 8]               0
          Conv2d-230             [-1, 64, 8, 8]          16,448
     BatchNorm2d-231             [-1, 64, 8, 8]             128
            ReLU-232             [-1, 64, 8, 8]               0
          Conv2d-233             [-1, 64, 8, 8]          36,928
     BatchNorm2d-234             [-1, 64, 8, 8]             128
            ReLU-235             [-1, 64, 8, 8]               0
          Conv2d-236            [-1, 256, 8, 8]          16,640
     BatchNorm2d-237            [-1, 256, 8, 8]             512
            ReLU-238            [-1, 256, 8, 8]               0
      Bottleneck-239            [-1, 256, 8, 8]               0
          Conv2d-240             [-1, 64, 8, 8]          16,448
     BatchNorm2d-241             [-1, 64, 8, 8]             128
            ReLU-242             [-1, 64, 8, 8]               0
          Conv2d-243             [-1, 64, 8, 8]          36,928
     BatchNorm2d-244             [-1, 64, 8, 8]             128
            ReLU-245             [-1, 64, 8, 8]               0
          Conv2d-246            [-1, 256, 8, 8]          16,640
     BatchNorm2d-247            [-1, 256, 8, 8]             512
            ReLU-248            [-1, 256, 8, 8]               0
      Bottleneck-249            [-1, 256, 8, 8]               0
          Conv2d-250             [-1, 64, 8, 8]          16,448
     BatchNorm2d-251             [-1, 64, 8, 8]             128
            ReLU-252             [-1, 64, 8, 8]               0
          Conv2d-253             [-1, 64, 8, 8]          36,928
     BatchNorm2d-254             [-1, 64, 8, 8]             128
            ReLU-255             [-1, 64, 8, 8]               0
          Conv2d-256            [-1, 256, 8, 8]          16,640
     BatchNorm2d-257            [-1, 256, 8, 8]             512
            ReLU-258            [-1, 256, 8, 8]               0
      Bottleneck-259            [-1, 256, 8, 8]               0
          Conv2d-260             [-1, 64, 8, 8]          16,448
     BatchNorm2d-261             [-1, 64, 8, 8]             128
            ReLU-262             [-1, 64, 8, 8]               0
          Conv2d-263             [-1, 64, 8, 8]          36,928
     BatchNorm2d-264             [-1, 64, 8, 8]             128
            ReLU-265             [-1, 64, 8, 8]               0
          Conv2d-266            [-1, 256, 8, 8]          16,640
     BatchNorm2d-267            [-1, 256, 8, 8]             512
            ReLU-268            [-1, 256, 8, 8]               0
      Bottleneck-269            [-1, 256, 8, 8]               0
          Conv2d-270             [-1, 64, 8, 8]          16,448
     BatchNorm2d-271             [-1, 64, 8, 8]             128
            ReLU-272             [-1, 64, 8, 8]               0
          Conv2d-273             [-1, 64, 8, 8]          36,928
     BatchNorm2d-274             [-1, 64, 8, 8]             128
            ReLU-275             [-1, 64, 8, 8]               0
          Conv2d-276            [-1, 256, 8, 8]          16,640
     BatchNorm2d-277            [-1, 256, 8, 8]             512
            ReLU-278            [-1, 256, 8, 8]               0
      Bottleneck-279            [-1, 256, 8, 8]               0
          Conv2d-280             [-1, 64, 8, 8]          16,448
     BatchNorm2d-281             [-1, 64, 8, 8]             128
            ReLU-282             [-1, 64, 8, 8]               0
          Conv2d-283             [-1, 64, 8, 8]          36,928
     BatchNorm2d-284             [-1, 64, 8, 8]             128
            ReLU-285             [-1, 64, 8, 8]               0
          Conv2d-286            [-1, 256, 8, 8]          16,640
     BatchNorm2d-287            [-1, 256, 8, 8]             512
            ReLU-288            [-1, 256, 8, 8]               0
      Bottleneck-289            [-1, 256, 8, 8]               0
          Conv2d-290             [-1, 64, 8, 8]          16,448
     BatchNorm2d-291             [-1, 64, 8, 8]             128
            ReLU-292             [-1, 64, 8, 8]               0
          Conv2d-293             [-1, 64, 8, 8]          36,928
     BatchNorm2d-294             [-1, 64, 8, 8]             128
            ReLU-295             [-1, 64, 8, 8]               0
          Conv2d-296            [-1, 256, 8, 8]          16,640
     BatchNorm2d-297            [-1, 256, 8, 8]             512
            ReLU-298            [-1, 256, 8, 8]               0
      Bottleneck-299            [-1, 256, 8, 8]               0
          Conv2d-300             [-1, 64, 8, 8]          16,448
     BatchNorm2d-301             [-1, 64, 8, 8]             128
            ReLU-302             [-1, 64, 8, 8]               0
          Conv2d-303             [-1, 64, 8, 8]          36,928
     BatchNorm2d-304             [-1, 64, 8, 8]             128
            ReLU-305             [-1, 64, 8, 8]               0
          Conv2d-306            [-1, 256, 8, 8]          16,640
     BatchNorm2d-307            [-1, 256, 8, 8]             512
            ReLU-308            [-1, 256, 8, 8]               0
      Bottleneck-309            [-1, 256, 8, 8]               0
          Conv2d-310             [-1, 64, 8, 8]          16,448
     BatchNorm2d-311             [-1, 64, 8, 8]             128
            ReLU-312             [-1, 64, 8, 8]               0
          Conv2d-313             [-1, 64, 8, 8]          36,928
     BatchNorm2d-314             [-1, 64, 8, 8]             128
            ReLU-315             [-1, 64, 8, 8]               0
          Conv2d-316            [-1, 256, 8, 8]          16,640
     BatchNorm2d-317            [-1, 256, 8, 8]             512
            ReLU-318            [-1, 256, 8, 8]               0
      Bottleneck-319            [-1, 256, 8, 8]               0
          Conv2d-320             [-1, 64, 8, 8]          16,448
     BatchNorm2d-321             [-1, 64, 8, 8]             128
            ReLU-322             [-1, 64, 8, 8]               0
          Conv2d-323             [-1, 64, 8, 8]          36,928
     BatchNorm2d-324             [-1, 64, 8, 8]             128
            ReLU-325             [-1, 64, 8, 8]               0
          Conv2d-326            [-1, 256, 8, 8]          16,640
     BatchNorm2d-327            [-1, 256, 8, 8]             512
            ReLU-328            [-1, 256, 8, 8]               0
      Bottleneck-329            [-1, 256, 8, 8]               0
          Conv2d-330             [-1, 64, 8, 8]          16,448
     BatchNorm2d-331             [-1, 64, 8, 8]             128
            ReLU-332             [-1, 64, 8, 8]               0
          Conv2d-333             [-1, 64, 8, 8]          36,928
     BatchNorm2d-334             [-1, 64, 8, 8]             128
            ReLU-335             [-1, 64, 8, 8]               0
          Conv2d-336            [-1, 256, 8, 8]          16,640
     BatchNorm2d-337            [-1, 256, 8, 8]             512
            ReLU-338            [-1, 256, 8, 8]               0
      Bottleneck-339            [-1, 256, 8, 8]               0
          Conv2d-340             [-1, 64, 8, 8]          16,448
     BatchNorm2d-341             [-1, 64, 8, 8]             128
            ReLU-342             [-1, 64, 8, 8]               0
          Conv2d-343             [-1, 64, 8, 8]          36,928
     BatchNorm2d-344             [-1, 64, 8, 8]             128
            ReLU-345             [-1, 64, 8, 8]               0
          Conv2d-346            [-1, 256, 8, 8]          16,640
     BatchNorm2d-347            [-1, 256, 8, 8]             512
            ReLU-348            [-1, 256, 8, 8]               0
      Bottleneck-349            [-1, 256, 8, 8]               0
          Conv2d-350             [-1, 64, 8, 8]          16,448
     BatchNorm2d-351             [-1, 64, 8, 8]             128
            ReLU-352             [-1, 64, 8, 8]               0
          Conv2d-353             [-1, 64, 8, 8]          36,928
     BatchNorm2d-354             [-1, 64, 8, 8]             128
            ReLU-355             [-1, 64, 8, 8]               0
          Conv2d-356            [-1, 256, 8, 8]          16,640
     BatchNorm2d-357            [-1, 256, 8, 8]             512
            ReLU-358            [-1, 256, 8, 8]               0
      Bottleneck-359            [-1, 256, 8, 8]               0
          Conv2d-360             [-1, 64, 8, 8]          16,448
     BatchNorm2d-361             [-1, 64, 8, 8]             128
            ReLU-362             [-1, 64, 8, 8]               0
          Conv2d-363             [-1, 64, 8, 8]          36,928
     BatchNorm2d-364             [-1, 64, 8, 8]             128
            ReLU-365             [-1, 64, 8, 8]               0
          Conv2d-366            [-1, 256, 8, 8]          16,640
     BatchNorm2d-367            [-1, 256, 8, 8]             512
            ReLU-368            [-1, 256, 8, 8]               0
      Bottleneck-369            [-1, 256, 8, 8]               0
          Conv2d-370             [-1, 64, 8, 8]          16,448
     BatchNorm2d-371             [-1, 64, 8, 8]             128
            ReLU-372             [-1, 64, 8, 8]               0
          Conv2d-373             [-1, 64, 8, 8]          36,928
     BatchNorm2d-374             [-1, 64, 8, 8]             128
            ReLU-375             [-1, 64, 8, 8]               0
          Conv2d-376            [-1, 256, 8, 8]          16,640
     BatchNorm2d-377            [-1, 256, 8, 8]             512
            ReLU-378            [-1, 256, 8, 8]               0
      Bottleneck-379            [-1, 256, 8, 8]               0
          Conv2d-380             [-1, 64, 8, 8]          16,448
     BatchNorm2d-381             [-1, 64, 8, 8]             128
            ReLU-382             [-1, 64, 8, 8]               0
          Conv2d-383             [-1, 64, 8, 8]          36,928
     BatchNorm2d-384             [-1, 64, 8, 8]             128
            ReLU-385             [-1, 64, 8, 8]               0
          Conv2d-386            [-1, 256, 8, 8]          16,640
     BatchNorm2d-387            [-1, 256, 8, 8]             512
            ReLU-388            [-1, 256, 8, 8]               0
      Bottleneck-389            [-1, 256, 8, 8]               0
          Conv2d-390             [-1, 64, 8, 8]          16,448
     BatchNorm2d-391             [-1, 64, 8, 8]             128
            ReLU-392             [-1, 64, 8, 8]               0
          Conv2d-393             [-1, 64, 8, 8]          36,928
     BatchNorm2d-394             [-1, 64, 8, 8]             128
            ReLU-395             [-1, 64, 8, 8]               0
          Conv2d-396            [-1, 256, 8, 8]          16,640
     BatchNorm2d-397            [-1, 256, 8, 8]             512
            ReLU-398            [-1, 256, 8, 8]               0
      Bottleneck-399            [-1, 256, 8, 8]               0
          Conv2d-400             [-1, 64, 8, 8]          16,448
     BatchNorm2d-401             [-1, 64, 8, 8]             128
            ReLU-402             [-1, 64, 8, 8]               0
          Conv2d-403             [-1, 64, 8, 8]          36,928
     BatchNorm2d-404             [-1, 64, 8, 8]             128
            ReLU-405             [-1, 64, 8, 8]               0
          Conv2d-406            [-1, 256, 8, 8]          16,640
     BatchNorm2d-407            [-1, 256, 8, 8]             512
            ReLU-408            [-1, 256, 8, 8]               0
      Bottleneck-409            [-1, 256, 8, 8]               0
          Conv2d-410             [-1, 64, 8, 8]          16,448
     BatchNorm2d-411             [-1, 64, 8, 8]             128
            ReLU-412             [-1, 64, 8, 8]               0
          Conv2d-413             [-1, 64, 8, 8]          36,928
     BatchNorm2d-414             [-1, 64, 8, 8]             128
            ReLU-415             [-1, 64, 8, 8]               0
          Conv2d-416            [-1, 256, 8, 8]          16,640
     BatchNorm2d-417            [-1, 256, 8, 8]             512
            ReLU-418            [-1, 256, 8, 8]               0
      Bottleneck-419            [-1, 256, 8, 8]               0
          Conv2d-420             [-1, 64, 8, 8]          16,448
     BatchNorm2d-421             [-1, 64, 8, 8]             128
            ReLU-422             [-1, 64, 8, 8]               0
          Conv2d-423             [-1, 64, 8, 8]          36,928
     BatchNorm2d-424             [-1, 64, 8, 8]             128
            ReLU-425             [-1, 64, 8, 8]               0
          Conv2d-426            [-1, 256, 8, 8]          16,640
     BatchNorm2d-427            [-1, 256, 8, 8]             512
            ReLU-428            [-1, 256, 8, 8]               0
      Bottleneck-429            [-1, 256, 8, 8]               0
          Conv2d-430             [-1, 64, 8, 8]          16,448
     BatchNorm2d-431             [-1, 64, 8, 8]             128
            ReLU-432             [-1, 64, 8, 8]               0
          Conv2d-433             [-1, 64, 8, 8]          36,928
     BatchNorm2d-434             [-1, 64, 8, 8]             128
            ReLU-435             [-1, 64, 8, 8]               0
          Conv2d-436            [-1, 256, 8, 8]          16,640
     BatchNorm2d-437            [-1, 256, 8, 8]             512
            ReLU-438            [-1, 256, 8, 8]               0
      Bottleneck-439            [-1, 256, 8, 8]               0
          Conv2d-440             [-1, 64, 8, 8]          16,448
     BatchNorm2d-441             [-1, 64, 8, 8]             128
            ReLU-442             [-1, 64, 8, 8]               0
          Conv2d-443             [-1, 64, 8, 8]          36,928
     BatchNorm2d-444             [-1, 64, 8, 8]             128
            ReLU-445             [-1, 64, 8, 8]               0
          Conv2d-446            [-1, 256, 8, 8]          16,640
     BatchNorm2d-447            [-1, 256, 8, 8]             512
            ReLU-448            [-1, 256, 8, 8]               0
      Bottleneck-449            [-1, 256, 8, 8]               0
          Conv2d-450             [-1, 64, 8, 8]          16,448
     BatchNorm2d-451             [-1, 64, 8, 8]             128
            ReLU-452             [-1, 64, 8, 8]               0
          Conv2d-453             [-1, 64, 8, 8]          36,928
     BatchNorm2d-454             [-1, 64, 8, 8]             128
            ReLU-455             [-1, 64, 8, 8]               0
          Conv2d-456            [-1, 256, 8, 8]          16,640
     BatchNorm2d-457            [-1, 256, 8, 8]             512
            ReLU-458            [-1, 256, 8, 8]               0
      Bottleneck-459            [-1, 256, 8, 8]               0
          Conv2d-460             [-1, 64, 8, 8]          16,448
     BatchNorm2d-461             [-1, 64, 8, 8]             128
            ReLU-462             [-1, 64, 8, 8]               0
          Conv2d-463             [-1, 64, 8, 8]          36,928
     BatchNorm2d-464             [-1, 64, 8, 8]             128
            ReLU-465             [-1, 64, 8, 8]               0
          Conv2d-466            [-1, 256, 8, 8]          16,640
     BatchNorm2d-467            [-1, 256, 8, 8]             512
            ReLU-468            [-1, 256, 8, 8]               0
      Bottleneck-469            [-1, 256, 8, 8]               0
          Conv2d-470             [-1, 64, 8, 8]          16,448
     BatchNorm2d-471             [-1, 64, 8, 8]             128
            ReLU-472             [-1, 64, 8, 8]               0
          Conv2d-473             [-1, 64, 8, 8]          36,928
     BatchNorm2d-474             [-1, 64, 8, 8]             128
            ReLU-475             [-1, 64, 8, 8]               0
          Conv2d-476            [-1, 256, 8, 8]          16,640
     BatchNorm2d-477            [-1, 256, 8, 8]             512
            ReLU-478            [-1, 256, 8, 8]               0
      Bottleneck-479            [-1, 256, 8, 8]               0
AdaptiveAvgPool2d-480            [-1, 256, 1, 1]               0
         Flatten-481                  [-1, 256]               0
          Linear-482                   [-1, 10]           2,570
================================================================
Total params: 2,741,386
Trainable params: 2,741,386
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 46.97
Params size (MB): 10.46
Estimated Total Size (MB): 57.44
----------------------------------------------------------------
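
The totals torchsummary reports can be cross-checked directly against the raw parameter tensors (a standard PyTorch idiom, reusing the `net` instance created above):

In [ ]:
# Cross-check the summary totals against the parameter tensors.
total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(total, trainable)  # expect 2741386 and 2741386, matching the summary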

In [25]:
summary(WideResNet(), input_size=(3, 32, 32))


| Wide-Resnet 28x2
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 32, 32, 32]           4,640
           Dropout-5           [-1, 32, 32, 32]               0
       BatchNorm2d-6           [-1, 32, 32, 32]              64
              ReLU-7           [-1, 32, 32, 32]               0
            Conv2d-8           [-1, 32, 32, 32]           9,248
            Conv2d-9           [-1, 32, 32, 32]             544
        WideBasic-10           [-1, 32, 32, 32]               0
      BatchNorm2d-11           [-1, 32, 32, 32]              64
             ReLU-12           [-1, 32, 32, 32]               0
           Conv2d-13           [-1, 32, 32, 32]           9,248
          Dropout-14           [-1, 32, 32, 32]               0
      BatchNorm2d-15           [-1, 32, 32, 32]              64
             ReLU-16           [-1, 32, 32, 32]               0
           Conv2d-17           [-1, 32, 32, 32]           9,248
        WideBasic-18           [-1, 32, 32, 32]               0
      BatchNorm2d-19           [-1, 32, 32, 32]              64
             ReLU-20           [-1, 32, 32, 32]               0
           Conv2d-21           [-1, 32, 32, 32]           9,248
          Dropout-22           [-1, 32, 32, 32]               0
      BatchNorm2d-23           [-1, 32, 32, 32]              64
             ReLU-24           [-1, 32, 32, 32]               0
           Conv2d-25           [-1, 32, 32, 32]           9,248
        WideBasic-26           [-1, 32, 32, 32]               0
      BatchNorm2d-27           [-1, 32, 32, 32]              64
             ReLU-28           [-1, 32, 32, 32]               0
           Conv2d-29           [-1, 32, 32, 32]           9,248
          Dropout-30           [-1, 32, 32, 32]               0
      BatchNorm2d-31           [-1, 32, 32, 32]              64
             ReLU-32           [-1, 32, 32, 32]               0
           Conv2d-33           [-1, 32, 32, 32]           9,248
        WideBasic-34           [-1, 32, 32, 32]               0
      BatchNorm2d-35           [-1, 32, 32, 32]              64
             ReLU-36           [-1, 32, 32, 32]               0
           Conv2d-37           [-1, 64, 32, 32]          18,496
          Dropout-38           [-1, 64, 32, 32]               0
      BatchNorm2d-39           [-1, 64, 32, 32]             128
             ReLU-40           [-1, 64, 32, 32]               0
           Conv2d-41           [-1, 64, 16, 16]          36,928
           Conv2d-42           [-1, 64, 16, 16]           2,112
        WideBasic-43           [-1, 64, 16, 16]               0
      BatchNorm2d-44           [-1, 64, 16, 16]             128
             ReLU-45           [-1, 64, 16, 16]               0
           Conv2d-46           [-1, 64, 16, 16]          36,928
          Dropout-47           [-1, 64, 16, 16]               0
      BatchNorm2d-48           [-1, 64, 16, 16]             128
             ReLU-49           [-1, 64, 16, 16]               0
           Conv2d-50           [-1, 64, 16, 16]          36,928
        WideBasic-51           [-1, 64, 16, 16]               0
      BatchNorm2d-52           [-1, 64, 16, 16]             128
             ReLU-53           [-1, 64, 16, 16]               0
           Conv2d-54           [-1, 64, 16, 16]          36,928
          Dropout-55           [-1, 64, 16, 16]               0
      BatchNorm2d-56           [-1, 64, 16, 16]             128
             ReLU-57           [-1, 64, 16, 16]               0
           Conv2d-58           [-1, 64, 16, 16]          36,928
        WideBasic-59           [-1, 64, 16, 16]               0
      BatchNorm2d-60           [-1, 64, 16, 16]             128
             ReLU-61           [-1, 64, 16, 16]               0
           Conv2d-62           [-1, 64, 16, 16]          36,928
          Dropout-63           [-1, 64, 16, 16]               0
      BatchNorm2d-64           [-1, 64, 16, 16]             128
             ReLU-65           [-1, 64, 16, 16]               0
           Conv2d-66           [-1, 64, 16, 16]          36,928
        WideBasic-67           [-1, 64, 16, 16]               0
      BatchNorm2d-68           [-1, 64, 16, 16]             128
             ReLU-69           [-1, 64, 16, 16]               0
           Conv2d-70          [-1, 128, 16, 16]          73,856
          Dropout-71          [-1, 128, 16, 16]               0
      BatchNorm2d-72          [-1, 128, 16, 16]             256
             ReLU-73          [-1, 128, 16, 16]               0
           Conv2d-74            [-1, 128, 8, 8]         147,584
           Conv2d-75            [-1, 128, 8, 8]           8,320
        WideBasic-76            [-1, 128, 8, 8]               0
      BatchNorm2d-77            [-1, 128, 8, 8]             256
             ReLU-78            [-1, 128, 8, 8]               0
           Conv2d-79            [-1, 128, 8, 8]         147,584
          Dropout-80            [-1, 128, 8, 8]               0
      BatchNorm2d-81            [-1, 128, 8, 8]             256
             ReLU-82            [-1, 128, 8, 8]               0
           Conv2d-83            [-1, 128, 8, 8]         147,584
        WideBasic-84            [-1, 128, 8, 8]               0
      BatchNorm2d-85            [-1, 128, 8, 8]             256
             ReLU-86            [-1, 128, 8, 8]               0
           Conv2d-87            [-1, 128, 8, 8]         147,584
          Dropout-88            [-1, 128, 8, 8]               0
      BatchNorm2d-89            [-1, 128, 8, 8]             256
             ReLU-90            [-1, 128, 8, 8]               0
           Conv2d-91            [-1, 128, 8, 8]         147,584
        WideBasic-92            [-1, 128, 8, 8]               0
      BatchNorm2d-93            [-1, 128, 8, 8]             256
             ReLU-94            [-1, 128, 8, 8]               0
           Conv2d-95            [-1, 128, 8, 8]         147,584
          Dropout-96            [-1, 128, 8, 8]               0
      BatchNorm2d-97            [-1, 128, 8, 8]             256
             ReLU-98            [-1, 128, 8, 8]               0
           Conv2d-99            [-1, 128, 8, 8]         147,584
       WideBasic-100            [-1, 128, 8, 8]               0
     BatchNorm2d-101            [-1, 128, 8, 8]             256
            ReLU-102            [-1, 128, 8, 8]               0
AdaptiveAvgPool2d-103            [-1, 128, 1, 1]               0
         Flatten-104                  [-1, 128]               0
          Linear-105                   [-1, 10]           1,290
================================================================
Total params: 1,469,642
Trainable params: 1,469,642
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 17.06
Params size (MB): 5.61
Estimated Total Size (MB): 22.68
----------------------------------------------------------------
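
Side by side, the Wide-ResNet 28x2 comes in at roughly half the parameters of this resnet152 variant at the same 32x32 input. The counts below are reproduced from fresh instances and should match the two summaries above:

In [ ]:
# Parameter counts for the two architectures, from the summaries above.
for name, build in [("resnet152", resnet152), ("WideResNet 28x2", WideResNet)]:
    n = sum(p.numel() for p in build().parameters())
    print(f"{name}: {n:,} params")
# resnet152: 2,741,386 params
# WideResNet 28x2: 1,469,642 params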
