``````

In [3]:

import itertools
import math
import os

import numpy as np
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

``````

# Sparse CNN

The idea of the sparse CNN is to use only the important values of its filters and to set the less important ones to zero.

For terminology, we'll call the weights of a specific channel a "filter", and we'll call the values of each filter a "connection". Via Hebbian-inspired learning, connections that see high absolute values in both their inputs and outputs will be favored. This comparison of input and output will be referred to as the connection's strength.

# Connection Strength Calculations

To start, we initialize a typical conv layer. From here, several separate methods will be used to calculate the strength of the connections.

``````

In [4]:

c_out = 7
c_in = 8
conv = torch.nn.Conv2d(c_in, c_out, kernel_size=(2, 2), stride=(1, 1), padding=0, dilation=1, groups=1)
input_tensor = torch.randn(2, c_in, 5, 3)
output_tensor = conv(input_tensor)

``````
``````

In [5]:

# Mirror the hyper-parameters of `conv`; the helper layers built below reuse them so
# that they share its geometry.
in_channels, out_channels = conv.in_channels, conv.out_channels  # channels consumed / produced
kernel_size = conv.kernel_size  # spatial extent of each filter
stride, dilation = conv.stride, conv.dilation  # step between windows / spacing inside a window
groups = conv.groups  # number of blocked input->output channel groups
bias = not (conv.bias is None)  # whether the layer carries an additive bias

``````

## Vectorized Method

The idea of this method is to utilize convolutional arithmetic to determine the input for a given output unit and a given connection.

Suppose we initialize a weight matrix with exactly the same dimensions as our original conv layer, and set all of its filters to 0 except for one connection. That is,

``new_conv.weight[:, c, j, h] = 1.``

Now if we pass the input through `new_conv`, we'll be given an output tensor of the same size as the original, but with the input values arranged at the locations of their respective output through the connection. That is,

``````old_output = conv(input)
new_output = new_conv(input)

# ==> for every batch index b and output position (n, m), we have
# new_output[b, :, n, m] = input[<index of the input element that connection conv.weight[:, c, j, h] sees for output position (n, m)>]

examine_connections(old_output, new_output) # done in pointwise fashion``````

With this vectorized calculation, we may then loop over all combinations of `c`, `j`, and `h`, compare the outputs to their respective inputs, and populate a matrix to record the strengths.

``````

In [6]:

def get_single_unit_conv(c, j, h, **kwargs):
    """
    Build a frozen ``torch.nn.Conv2d`` whose weights are all zero except the single
    connection ``(c, j, h)``, which is set to 1 for every output channel.

    Feeding an input through this layer places, at each output position, the input
    value that connection ``(c, j, h)`` sees for that position (plus the bias, if the
    layer has one; callers pass ``bias=False``).

    Args:
        c, j, h: input-channel, kernel-row and kernel-column index of the connection.
        **kwargs: forwarded unchanged to ``torch.nn.Conv2d``.

    Returns:
        The conv layer, in eval mode with gradient tracking disabled on its parameters.
    """
    conv = torch.nn.Conv2d(**kwargs)

    # Freeze the layer. Assigning ``conv.train = False`` would only shadow the
    # ``Module.train`` method with a bool; it does not stop gradient tracking.
    conv.eval()
    conv.requires_grad_(False)

    # Overwrite the weights in place without recording the edit in autograd.
    with torch.no_grad():
        conv.weight.zero_()
        conv.weight[:, c, j, h] = 1

    return conv

# Every connection (c, j, h) of a filter: input channel x kernel row x kernel column.
single_unit_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

# One frozen helper conv per connection. Padding is forwarded too, so each helper's
# output has the same spatial size as `conv`'s output for any padding setting.
single_unit_convs = [
    get_single_unit_conv(
        c, j, h,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=conv.padding,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    for c, j, h in single_unit_indxs
]

``````
``````

In [7]:

def f1():
    """
    Connection strengths using one single-unit conv per connection.

    For each connection, counts the (batch, output position) pairs at which the output
    of `conv` and the input routed through that connection are both active, i.e. have
    sigmoid > 0.5.
    """
    strengths = torch.zeros_like(conv.weight)
    out_active = torch.sigmoid(output_tensor).gt_(0.5)
    for (c, j, h), unit_conv in zip(single_unit_indxs, single_unit_convs):
        in_active = torch.sigmoid(unit_conv(input_tensor)).gt_(0.5)
        strengths[:, c, j, h] += torch.sum(out_active.mul(in_active), (0, 2, 3))
    return strengths

``````

## Vectorized Method + Grouping

Same as the previous method, but using the `groups` argument of the conv layer so that a single conv layer handles all connections at once.

``````

In [8]:

def get_single_unit_weights(shape, c, j, h, **kwargs):
    """
    Return a zero weight tensor of the given ``shape`` with the single connection
    ``(c, j, h)`` set to 1 across every output channel (dimension 0).

    Extra keyword arguments are accepted and ignored.
    """
    one_hot = torch.zeros(shape)
    one_hot[(slice(None), c, j, h)] = 1
    return one_hot

# Every connection (c, j, h) of a filter.
filter_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

# Full (out_channel, c, j, h) index of every channel of `stacked_conv`'s output, in
# its channel order: connection filter_indxs[g] first, then output channel. It is
# transposed into one index sequence per weight dimension so that it can be used for
# advanced indexing of a weight-shaped tensor.
connection_indxs = [
    [c_] + list(idx)
    for idx in filter_indxs
    for c_ in range(out_channels)
]
connection_indxs = list(zip(*connection_indxs))

# One grouped conv that evaluates every single-unit conv at once: group g sees its own
# copy of the input and only has connection filter_indxs[g] switched on. Padding is
# forwarded so the output lines up with `conv`'s output.
new_groups = len(filter_indxs)
stacked_conv = torch.nn.Conv2d(
    in_channels=in_channels * new_groups,
    out_channels=out_channels * new_groups,
    kernel_size=kernel_size,
    stride=stride,
    padding=conv.padding,
    dilation=dilation,
    groups=groups * new_groups,
    bias=False,
)

# Stacked one-hot selector weights, one block of output channels per connection.
single_unit_weights = [
    get_single_unit_weights(
        conv.weight.shape,
        c, j, h,
    )
    for c, j, h in filter_indxs
]
# The selector weights are fixed: freeze them and write them without autograd tracking.
stacked_conv.requires_grad_(False)
with torch.no_grad():
    stacked_conv.weight.copy_(torch.cat(single_unit_weights, dim=0))

``````
``````

In [9]:

def f2():
    """
    Connection strengths using a single grouped conv (``stacked_conv``).

    Computes the same quantity as ``f1``: for each connection, the number of
    (batch, output position) pairs at which the output of ``conv`` and the input routed
    through that connection both have sigmoid > 0.5.
    """
    # One copy of the input per group, i.e. per connection.
    stacked_input = input_tensor.repeat((1, new_groups, 1, 1))
    stacked_output = stacked_conv(stacked_input)

    H = torch.zeros_like(conv.weight)

    s1 = torch.sigmoid(stacked_output).gt_(0.5)
    # Tile the real output so that its channels line up with stacked_output's
    # (connection, output channel) layout.
    s2 = torch.sigmoid(output_tensor).gt_(0.5).repeat((1, new_groups, 1, 1))

    H_ = torch.sum(s2.mul(s1), (0, 2, 3))

    # Scatter the flat per-channel counts back to their (out, c, j, h) positions.
    H[tuple(connection_indxs)] = H_
    return H

``````

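## Vectorized Method with Fewer Redundancies

Like the grouped method, but the stacked conv has only one output channel per connection, because the input seen through a connection is the same for every output channel. Its output is then repeated across the output channels. Note that `f2b` scores each connection by the mean absolute input/output correlation over the batch rather than by a coactivation count, so its values are not directly comparable with those of `f1`, `f2` and `f3`.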

``````

In [10]:

def get_single_unit_weights_2b(shape, c, j, h, **kwargs):
    """
    Return a ``(1, *shape[1:])`` zero tensor with the single connection ``(c, j, h)``
    set to 1: a one-hot weight for just one output channel.

    ``shape[0]`` and any extra keyword arguments are ignored.
    """
    one_hot = torch.zeros((1,) + tuple(shape[1:]))
    one_hot[0, c, j, h] = 1
    return one_hot

# Every connection (c, j, h) of a filter.
filter_indxs_2b = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

# Full (out_channel, c, j, h) index of every connection, ordered output channel first
# and then connection. This is the channel order obtained in f2b by repeating
# stacked_conv_2b's output `out_channels` times. Transposed into one index sequence per
# weight dimension for advanced indexing.
connection_indxs_2b = list(zip(*[
    [c_] + list(idx)
    for c_ in range(out_channels)
    for idx in filter_indxs_2b
]))

# Number of connections per filter (= number of groups in stacked_conv_2b).
new_groups_2b = int(np.prod(conv.weight.shape[1:]))

# Output channel of `conv` that matches each channel of the repeated stacked output.
perm_indices_2b = [c_i for c_i in range(out_channels) for _ in range(new_groups_2b)]

# One output channel per connection: the input seen through a connection is the same
# for every output channel, so one copy per connection is enough. Padding is forwarded
# so the output lines up with `conv`'s output.
# NOTE(review): out_channels=new_groups_2b must be divisible by groups * new_groups_2b,
# so this only builds for groups == 1.
stacked_conv_2b = torch.nn.Conv2d(
    in_channels=in_channels * new_groups_2b,
    out_channels=new_groups_2b,
    kernel_size=kernel_size,
    stride=stride,
    padding=conv.padding,
    dilation=dilation,
    groups=groups * new_groups_2b,
    bias=False,
)

# One-hot selector weight for each connection, stacked along the output-channel axis.
single_unit_weights_2b = [
    get_single_unit_weights_2b(
        conv.weight.shape,
        c, j, h,
    )
    for c, j, h in filter_indxs_2b
]
# The selector weights are fixed: freeze them and write them without autograd tracking.
stacked_conv_2b.requires_grad_(False)
with torch.no_grad():
    stacked_conv_2b.weight.copy_(torch.cat(single_unit_weights_2b, dim=0))

``````
``````

In [11]:

def f2b():
    """
    Connection strengths as the mean absolute input/output correlation.

    For every connection (out, c, j, h) and every output position, correlates over the
    batch the output of ``conv`` at channel ``out`` with the input routed through
    connection (c, j, h). The absolute correlations are then averaged over output
    positions. Positions where either signal has zero variance across the batch
    contribute 0.

    Unlike f1/f2/f3, the result is a correlation in [0, 1], not a coactivation count.
    """
    with torch.no_grad():
        # One output channel per connection, tiled so that channel o * G + g pairs
        # output channel o of `conv` with connection g.
        stacked_input = input_tensor.repeat((1, new_groups_2b, 1, 1))
        s1 = stacked_conv_2b(stacked_input).repeat((1, out_channels, 1, 1))
        s2 = output_tensor[:, perm_indices_2b, ...]

        mu_in = s1.mean(dim=0)
        mu_out = s2.mean(dim=0)

        # Population (biased) std, so the normalisation matches the covariance below,
        # which is a plain mean over the batch. This also gives std == 0, rather than
        # NaN, for a batch of one.
        std_in = s1.std(dim=0, unbiased=False)
        std_out = s2.std(dim=0, unbiased=False)

        corr = ((s1 - mu_in) * (s2 - mu_out)).mean(dim=0) / (std_in * std_out)

        # Correlation is undefined where either signal is constant; score it as 0.
        corr[(std_in == 0) | (std_out == 0)] = 0
        H_ = torch.mean(corr.abs(), (1, 2))

    H = torch.zeros_like(conv.weight)
    H[tuple(connection_indxs_2b)] = H_
    return H

``````

## Brute Force Method

Computationally speaking, this is the same method as the first two (`f1` and `f2`). Only now, instead of using conv layers to assist in the computations, we use `for` loops to brute-force our way through.

This serves mainly as a sanity check to validate their outputs.

``````

In [12]:

def coactivation(t1, t2):
    """Elementwise: True where both inputs are active, i.e. have sigmoid above 0.5."""
    def _active(t):
        return torch.sigmoid(t) > 0.5

    return _active(t1) * _active(t2)

def get_indeces_of_input_and_filter(n, m, layer=None):
    """
    List the input elements that feed output position ``(n, m)`` of a 2-d convolution,
    each paired with the filter connection that links it to that output.

    Args:
        n, m: row and column of the output position.
        layer: conv layer whose geometry (kernel size, stride, padding, dilation,
            input channels) is used. Defaults to the notebook-level ``conv``.

    Returns:
        A list of ``(input_indx, filter_indx)`` pairs. ``input_indx`` is
        ``(c_in, row, col)`` in the *unpadded* input, and ``filter_indx`` is
        ``(c_in, kernel_row, kernel_col)``. With padding, ``row``/``col`` may fall
        outside the input (negative or too large). Such entries correspond to padding
        and must be skipped by the caller.

    Assumes ``groups == 1`` (every input channel feeds every output channel) and
    integer padding.
    """
    if layer is None:
        layer = conv

    k1, k2 = layer.kernel_size
    s1, s2 = layer.stride
    p1, p2 = layer.padding
    d1, d2 = layer.dilation

    # Top-left input coordinate of the receptive field of output (n, m).
    i1 = n * s1 - p1
    i2 = m * s2 - p2

    return [
        ((c_in, i1 + n_k1 * d1, i2 + m_k2 * d2), (c_in, n_k1, m_k2))
        for c_in in range(layer.in_channels)
        for n_k1 in range(k1)
        for m_k2 in range(k2)
    ]

``````
``````

In [18]:

B     = output_tensor.shape[0]
N_out = output_tensor.shape[2]
M_out = output_tensor.shape[3]
C_in  = conv.weight.shape[1]
C_out = conv.weight.shape[0]

def f3():
    """
    Brute-force reference for the connection strengths computed by f1/f2.

    For every batch element, output unit and filter connection, adds 1 to
    ``H[c_out, c, j, h]`` when the output unit and the input element it sees through
    connection (c, j, h) are both active (sigmoid > 0.5). Connections that land on zero
    padding are skipped, since a padded zero is never active (sigmoid(0) == 0.5).
    """
    N_in, M_in = input_tensor.shape[2], input_tensor.shape[3]
    H = torch.zeros_like(conv.weight)
    for b in range(B):
        for c_out in range(C_out):
            for n_out in range(N_out):
                for m_out in range(M_out):
                    unit_1 = output_tensor[b, c_out, n_out, m_out]
                    indxs  = get_indeces_of_input_and_filter(n_out, m_out)

                    for input_indx, filter_indx in indxs:
                        c_in, n_in, m_in = input_indx
                        c_fl, n_fl, m_fl = filter_indx

                        # Negative indices would silently wrap around in Python.
                        # Positions outside the input are padding and cannot coactivate.
                        if not (0 <= n_in < N_in and 0 <= m_in < M_in):
                            continue

                        unit_2 = input_tensor[b, c_in, n_in, m_in]
                        if coactivation(unit_1, unit_2):
                            H[c_out, c_fl, n_fl, m_fl] += 1

    return H

``````

## Validation / Time Trials

Quick test to make sure the methods give consistent outputs. Let's also see how long they take to run.

``````

In [19]:

assert f1().allclose(f2(), rtol=0, atol=0) and f2().allclose(f2b(), rtol=0, atol=0) and f2b().allclose(f3(), rtol=0, atol=0)
%timeit f1()
%timeit f2()
%timeit f2b()
# %timeit f3()

``````
``````

9.64 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
326 µs ± 6.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
348 µs ± 5.33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

``````

# Sparse Conv Layer

Next, we implement a conv layer, `DSConv2d`, that uses the second (grouped) implementation above.

``````

In [14]:

class DSConv2d(torch.nn.Conv2d):
    """
    ``torch.nn.Conv2d`` that additionally records Hebbian-style connection strengths.

    In training mode, every ``update_interval``-th forward call (starting with the first)
    recomputes ``connections_tensor``. For each weight
    ``(out_channel, in_channel, kernel_row, kernel_col)`` it stores the number of
    (batch, output position) pairs at which the output unit and the input element seen
    through that weight are both active (sigmoid above ``activity_threshold``). The
    counts come from a frozen grouped conv, ``stacked_conv``, that routes each
    connection's input to its own block of output channels.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.learning_iterations = 0
        self.update_interval = 20  # forward calls between strength updates

        self.activity_threshold = 0.5  # applied to sigmoid(activation)
        # Per-slice connection counts used by progress_connections
        # (10% / 15% of the kernel area, at least one each).
        self.k1 = max(int(0.1 * np.prod(self.weight.shape[2:])), 1)
        self.k2 = max(int(0.15 * np.prod(self.weight.shape[2:])), 1)
        self.prune_dims = [0, 1]  # Todo: sort

        self.connections_tensor = torch.zeros_like(self.weight)

        # Every connection (c, j, h) of a filter.
        filter_indxs = list(itertools.product(*[range(d) for d in self.weight.shape[1:]]))

        # Full (out_channel, c, j, h) index of each channel of stacked_conv's output
        # (connection first, then output channel), transposed into one index sequence
        # per weight dimension for advanced indexing.
        self.connection_indxs = list(zip(*[
            [c] + list(idx)
            for idx in filter_indxs
            for c in range(self.weight.shape[0])
        ]))

        # Grouped conv that evaluates every single-connection conv at once: group g sees
        # its own copy of the input and only has connection filter_indxs[g] switched on.
        # Padding (and its mode) must match this layer's so that the outputs line up.
        self.new_groups = len(filter_indxs)
        self.stacked_conv = torch.nn.Conv2d(
            in_channels=self.in_channels * self.new_groups,
            out_channels=self.out_channels * self.new_groups,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * self.new_groups,
            bias=False,
            padding_mode=self.padding_mode,
        )

        # The selector weights are fixed: freeze them so they never receive gradients,
        # and write them without recording the edit in autograd.
        self.stacked_conv.requires_grad_(False)
        with torch.no_grad():
            self.stacked_conv.weight.copy_(torch.cat([
                self.get_single_unit_weights(self.weight.shape, c, j, h)
                for c, j, h in filter_indxs
            ], dim=0))

    def get_single_unit_weights(self, shape, c, j, h):
        """
        Return a zero tensor of ``shape`` whose connection ``(c, j, h)`` is 1 for every
        output channel (dimension 0).
        """
        weight = torch.zeros(shape)
        weight[:, c, j, h] = 1
        return weight

    def update_connections_tensor(self, input_tensor, output_tensor):
        """
        Recompute ``connections_tensor`` from one forward pass.

        Each entry becomes the number of (batch, output position) pairs at which the
        output unit and the input element seen through that connection both have
        sigmoid above ``activity_threshold``. Runs without autograd, because these
        statistics must not become part of the training graph.
        """
        with torch.no_grad():
            stacked_input = input_tensor.repeat((1, self.new_groups, 1, 1))
            stacked_output = self.stacked_conv(stacked_input)

            s1 = torch.sigmoid(stacked_output).gt_(self.activity_threshold)
            s2 = torch.sigmoid(output_tensor).gt_(self.activity_threshold).repeat(
                (1, self.new_groups, 1, 1))
            H_ = torch.sum(s2.mul(s1), (0, 2, 3))

            # NOTE(review): this overwrites rather than accumulates, so only the most
            # recently sampled batch is reflected when progress_connections() runs.
            # Confirm whether accumulation (+=) is intended.
            self.connections_tensor[tuple(self.connection_indxs)] = H_

    def progress_connections(self):
        """
        Per-slice connection update driven by ``connections_tensor``, then reset it.

        For every combination of indices over ``prune_dims`` (by default every
        (out_channel, in_channel) kernel), this finds the strengths of the k1-th and
        k2-th strongest connections. The weight update they are meant to drive (keep
        the strongest k1 connections, prune the rest and reinitialise the trailing
        k2 - k1) is not implemented yet, so the weights are currently left unchanged.
        """
        strengths = self.connections_tensor.detach().cpu().numpy()
        shape = strengths.shape

        # Index tuples that select one slice per combination of the prune dimensions,
        # e.g. (out, in, :, :) for prune_dims == [0, 1].
        prune_indxs = itertools.product(*[
            range(shape[d]) if d in self.prune_dims else [slice(None)]
            for d in range(len(shape))
        ])

        k1, k2 = self.k1, self.k2
        for idx in prune_indxs:
            s = strengths[idx].flatten()
            # Strength of the k1-th / k2-th strongest connection in this slice.
            v1 = np.partition(s, -k1)[-k1]
            v2 = np.partition(s, -k2)[-k2]
            # TODO: zero the weights whose strength is below v1 and reinitialise those
            # with strength in [v2, v1), e.g. via kaiming_uniform_ under torch.no_grad().

        # Start collecting strengths afresh for the next period.
        self.connections_tensor = torch.zeros_like(self.weight)

    def prune_randomly(self, prune_fraction=0.15):
        """
        Random-pruning baseline: zero a random ``prune_fraction`` of the weights.

        Weights that are exactly zero before the call (e.g. pruned by an earlier call)
        are first reinitialised with fresh Kaiming-uniform values, so the fraction of
        pruned weights stays around ``prune_fraction`` instead of compounding.
        NOTE(review): this interprets the original "Reinitialize those that are zero"
        comment; confirm the intended schedule.
        """
        with torch.no_grad():
            # Regrow previously pruned weights (a=sqrt(5), as nn.Conv2d's default init).
            zero_mask = self.weight == 0
            if bool(zero_mask.any()):
                fresh = torch.empty_like(self.weight)
                torch.nn.init.kaiming_uniform_(fresh, a=math.sqrt(5))
                self.weight[zero_mask] = fresh[zero_mask]

            # Keep each weight with probability 1 - prune_fraction.
            keep_mask = torch.rand(self.weight.shape, device=self.weight.device) < 1.0 - prune_fraction
            self.weight.mul_(keep_mask.to(self.weight.dtype))

    def __call__(self, input_tensor, *args, **kwargs):
        """
        Run the convolution. In training mode, refresh ``connections_tensor`` every
        ``update_interval``-th call; evaluation passes are not learning iterations.
        """
        output_tensor = super().__call__(input_tensor, *args, **kwargs)
        if self.training:
            if self.learning_iterations % self.update_interval == 0:
                self.update_connections_tensor(input_tensor, output_tensor)
            self.learning_iterations += 1
        return output_tensor

``````

# Test Training a Network

The following is a simple toy example copied mostly from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

The main difference here is that the network uses the `DSConv2d` module. This exercise mostly serves to gain confidence that the implementation can run without errors; it is not concerned with verifying training improvements just yet.

``````

In [15]:

# Where CIFAR-10 is stored (and downloaded to if missing). Override it with the
# CIFAR10_ROOT environment variable; "~" is expanded to the user's home directory.
root_path = os.path.expanduser(os.environ.get('CIFAR10_ROOT', '~/nta/datasets'))

``````
``````

In [21]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The datasets take the transform and download flag. Shuffling and worker processes
# are DataLoader options; CIFAR10 itself does not accept them. batch_size=4 follows the
# linked tutorial and the per-batch evaluation below.
trainset = torchvision.datasets.CIFAR10(root=root_path, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=root_path, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

``````
``````

``````

## Time Trials

Quick test to compare runtime with and without updating the connections tensor.

``````

In [22]:

cd = torch.nn.Conv2d(3, 6, 5)
cs = DSConv2d(3, 6, 5)

images, labels = dataiter.next()

print('Dense CNN foward pass:')
%timeit cd(images)
print('DSConv2d foward pass:')
%timeit cs(images)

``````
``````

Dense CNN foward pass:
215 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
DSConv2d foward pass:
3.59 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

``````

## Network Setup

``````

In [278]:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small LeNet-style CIFAR-10 classifier whose two conv layers are DSConv2d."""

    def __init__(self):
        super().__init__()
        # Same sub-module registration order as the dense tutorial network.
        self.conv1 = DSConv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = DSConv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, twice.
        for layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(layer(x)))
        x = x.view(-1, 16 * 5 * 5)
        # Two hidden fully connected layers, then the class logits.
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return self.fc3(x)

net = Net()

``````

## Training

``````

In [279]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

N_EPOCHS = 2       # passes over the training set
LOG_EVERY = 2000   # mini-batches between loss reports / connection updates
SMOKE_TEST = True  # stop after the first report; set False to train for all N_EPOCHS

for epoch in range(N_EPOCHS):

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):

        # Gradients accumulate in PyTorch, so clear the previous step's before backprop.
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last LOG_EVERY mini-batches, then let each
        # sparse layer act on the connection strengths it has collected.
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

            # For a random-pruning baseline, call net.conv1.prune_randomly() and
            # net.conv2.prune_randomly() instead.
            net.conv1.progress_connections()
            net.conv2.progress_connections()

            if SMOKE_TEST:
                break

    if SMOKE_TEST:
        break

print('Finished Training')

``````
``````

[1,  2000] loss: 2.221
Finished Training

``````

## Testing Accuracy

``````

In [17]:

# Per-class accuracy over the full test set (not the last training batch).
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)

net.eval()
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # Per-sample correctness; zipping over the batch also handles a smaller last batch.
        for label, correct in zip(labels.tolist(), (predicted == labels).tolist()):
            class_correct[label] += correct
            class_total[label] += 1
net.train()

for i, name in enumerate(classes):
    if class_total[i]:
        print('Accuracy of %5s : %2d %%' % (
            name, 100 * class_correct[i] / class_total[i]))
    else:
        print('Accuracy of %5s : n/a (no test samples)' % name)

``````
``````

Accuracy of plane :  0 %
Accuracy of   car :  0 %
Accuracy of  bird : 79 %
Accuracy of   cat :  2 %
Accuracy of  deer : 10 %
Accuracy of   dog :  0 %
Accuracy of  frog :  1 %
Accuracy of horse :  0 %
Accuracy of  ship :  0 %
Accuracy of truck :  0 %

``````