In [3]:
import itertools 
import math

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import numpy as np

Sparse CNN

The idea of the sparse CNN is to utilize only the important values of its filters and set those less important to zero.

For terminology, we'll call the weights of a specific channel a "filter", and we'll call the values of each filter a "connection". Via Hebbian-inspired learning, those connections that see high absolute values in both their inputs and outputs will be favored. This comparison of input and output will be referred to as the connection's strength.
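
As a rough sketch of the strength measure used throughout (the helper below is hypothetical, not part of the implementations that follow), a connection is counted as coactive on a sample when both its input unit and output unit are "on", i.e. their sigmoid exceeds 0.5:

def connection_strength(input_units, output_units):
    # Hypothetical helper: a unit is "on" when sigmoid(value) > 0.5,
    # which for raw values simply means value > 0.
    s_in = torch.sigmoid(input_units).gt(0.5)
    s_out = torch.sigmoid(output_units).gt(0.5)
    # Strength = number of coactivations across the batch.
    return s_in.mul(s_out).sum().item()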

Connection Strength Calculations

Initialize a typical conv layer to start. From here, three separate methods will be attempted to calculate the strengths of the connections.


In [4]:
c_out = 7
c_in = 8
conv = torch.nn.Conv2d(c_in, c_out, kernel_size=(2, 2), stride=(1, 1), padding=0, dilation=1, groups=1)
input_tensor = torch.randn(2, c_in, 5, 3)
output_tensor = conv(input_tensor)

In [5]:
# Get params of conv layer.
in_channels = conv.in_channels  # Number of channels in the input image
out_channels = conv.out_channels  # Number of channels produced by the convolution
kernel_size = conv.kernel_size  # Size of the convolving kernel
stride = conv.stride  # Stride of the convolution. Default: 1
padding = conv.padding  # Zero-padding added to both sides of the input. Default: 0
padding_mode = conv.padding_mode  # zeros
dilation = conv.dilation  # Spacing between kernel elements. Default: 1
groups = conv.groups  # Number of blocked connections from input channels to output channels. Default: 1
bias = conv.bias is not None

Vectorized Method

The idea of this method is to utilize convolutional arithmetic to determine the input for a given output unit and a given connection.

Suppose we initialize a weight matrix with exactly the same dimensions as our original conv layer, and set all of its filters to 0 except for one connection. That is,

new_conv.weight[:, c, j, h] = 1.

Now if we pass the input through new_conv, we'll get an output tensor of the same size as the original, but with each input value placed at the location of the output it feeds through the connection. That is,

old_output = conv(input)
new_output = new_conv(input)

# ==> for all b, j,  and h (b being the batch), we have 
# new_output[b, :, j, h] = input[<indices of input passed through connection conv.weight[:, c, j, h]>]

examine_connections(old_output, new_output) # done in pointwise fashion

With this vectorized calculation, we may then loop over all combinations of c, j, and h, compare the outputs to their respective inputs, and populate a matrix to record the strengths.


In [6]:
def get_single_unit_conv(c, j, h, **kwargs):
    """
    Constructs and returns a conv layer with training disabled and
    all zero weights except along the output channels for unit
    specified as (c, j, h).
    """
    
    # Construct conv.
    conv = torch.nn.Conv2d(**kwargs)
    
    # Turn off training.
    conv.train(False)
    
    # Set weights to zero except those specified.
    with torch.no_grad():
        conv.weight.set_(torch.zeros_like(conv.weight))
        conv.weight[:, c, j, h] = 1
        
    return conv

# Get indices that loop over all connections.
single_unit_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

single_unit_convs = [
    get_single_unit_conv(
        c, j, h,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        padding_mode=padding_mode,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    for c, j, h in single_unit_indxs
]
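
As a quick sanity check of the arithmetic (a hypothetical probe; since padding is 0 and dilation is 1 here, output unit (n, m) reads input[n*stride + j, m*stride + h] through connection (c, j, h)):

probe = get_single_unit_conv(
    0, 1, 0,  # c, j, h
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    bias=False,
)
n, m = 2, 1  # an arbitrary output location
expected = input_tensor[:, 0, n * stride[0] + 1, m * stride[1] + 0]
assert torch.allclose(probe(input_tensor)[:, 0, n, m], expected)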

In [7]:
def f1():
    """
    Calculate connection strengths.
    """
    H = torch.zeros_like(conv.weight)

    # "On" states of the original outputs.
    s1 = torch.sigmoid(output_tensor).gt_(0.5)
    for idx, sconv in zip(single_unit_indxs, single_unit_convs):

        # "On" states of the inputs routed through connection idx.
        s2 = torch.sigmoid(sconv(input_tensor)).gt_(0.5)

        # Count coactivations over the batch and spatial dimensions.
        m = torch.sum(s1.mul(s2), (0, 2, 3,))

        H[:, idx[0], idx[1], idx[2]] += m
        
    return H

Vectorized Method + Grouping

Same as the previous method, but utilizing the groups argument of the conv layer so that only one convolution (and thus one forward pass) is needed. A standalone sketch of the trick follows.
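
To see why this works (a minimal standalone sketch; G = 3 is an arbitrary choice for illustration), a conv with groups=G applied to the input repeated G times along the channel dimension behaves like G independent convs evaluated in a single forward pass:

G = 3
grouped = torch.nn.Conv2d(c_in * G, c_out * G, kernel_size=(2, 2), groups=G, bias=False)
stacked = grouped(input_tensor.repeat(1, G, 1, 1))
# Each group of c_out channels is produced from its own copy of the input.
assert stacked.shape[1] == c_out * G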


In [8]:
def get_single_unit_weights(shape, c, j, h, **kwargs):
    """
    Constructs and returns a weight tensor with all zero values
    except along the output channels for the unit specified
    as (c, j, h).
    """
    
    # Construct an all-zero weight tensor.
    weight = torch.zeros(shape)
    
    # Set the specified connection to one along all output channels.
    weight[:, c, j, h] = 1
        
    return weight

# Compute indices that loop over all connections of a channel.
filter_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

# Compute indices that loop over all channels and filters.
# This will be used to unpack the pointwise comparisons of the output.
connection_indxs = []
for idx in filter_indxs:
    i_ = list(idx)
    connection_indxs.extend([
        [c_]+i_ for c_ in range(out_channels)
    ])
connection_indxs = list(zip(*connection_indxs))

# Create new conv layer that groups its input and output.
new_groups = len(filter_indxs)
stacked_conv = torch.nn.Conv2d(
    in_channels=in_channels * new_groups,
    out_channels=out_channels * new_groups,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    padding_mode=padding_mode,
    dilation=dilation,
    groups=groups * new_groups,
    bias=False,
)

# Populate the weight matrix with stacked tensors having only one non-zero unit.
single_unit_weights = [
    get_single_unit_weights(
        conv.weight.shape,
        c, j, h,
    )
    for c, j, h in filter_indxs
]
with torch.no_grad():
    stacked_conv.weight.set_(torch.cat(single_unit_weights, dim=0))

In [9]:
def f2():
    stacked_input = input_tensor.repeat((1, new_groups, 1, 1))
    stacked_output = stacked_conv(stacked_input)
    
    H = torch.zeros_like(conv.weight)

    # "On" states of the inputs routed through each connection (stacked
    # along the channel dim) and of the original outputs, repeated to match.
    s1 = torch.sigmoid(stacked_output).gt_(0.5)
    s2 = torch.sigmoid(output_tensor).gt_(0.5).repeat((1, new_groups, 1, 1))
    
    # Count coactivations and unpack them into the strength matrix.
    H_ = torch.sum(s2.mul(s1), (0, 2, 3,))

    H[connection_indxs] = H_
    
    return H

Vectorized Method with Fewer Redundancies

Every output channel of a given probe conv above computes an identical map, so each probe really needs only one output channel; its result can then be repeated across output channels. This variant exploits that, and (as written) also swaps the strength measure to the mean absolute correlation between each connection's input and output activity.

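A quick sketch of the redundancy being removed (standalone, using the functional conv directly): with weight[:, c, j, h] = 1, every output channel of a probe computes the identical map, so a single channel per probe suffices.

w = torch.zeros(out_channels, in_channels, *kernel_size)
w[:, 0, 0, 0] = 1
out = torch.nn.functional.conv2d(input_tensor, w)
# All output channels carry the same values.
assert torch.allclose(out[:, 0], out[:, 1])
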
In [10]:
def get_single_unit_weights_2b(shape, c, j, h, **kwargs):
    """
    Constructs and returns a weight tensor with a single output
    channel, all zero values except for the unit specified
    as (c, j, h).
    """
    
    # Construct an all-zero weight tensor with one output channel.
    weight = torch.zeros(1, *shape[1:])
    
    # Set the specified connection to one.
    weight[0, c, j, h] = 1
        
    return weight

# Compute indices that loop over all connections of a channel.
filter_indxs_2b = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))

# Compute indices that loop over all channels and filters.
# This will be used to unpack the pointwise comparisons of the output.
connection_indxs_2b = []
for c_ in range(out_channels):
    for idx in filter_indxs_2b:
        i_ = list(idx)
        connection_indxs_2b.append([c_] + i_)
connection_indxs_2b = list(zip(*connection_indxs_2b))

new_groups_2b = int(np.prod(conv.weight.shape[1:]))
perm_indices_2b = []
for c_i in range(out_channels):
    perm_indices_2b.extend(
        [c_i] * new_groups_2b
    )

# Create new conv layer that groups its input and output.
stacked_conv_2b = torch.nn.Conv2d(
    in_channels=in_channels * new_groups_2b,
    out_channels=new_groups_2b,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    padding_mode=padding_mode,
    dilation=dilation,
    groups=groups * new_groups_2b,
    bias=False,
)
# Populate the weight matrix with stacked tensors having only one non-zero unit.
single_unit_weights_2b = [
    get_single_unit_weights_2b(
        conv.weight.shape,
        c, j, h,
    )
    for c, j, h in filter_indxs_2b
]
with torch.no_grad():
    stacked_conv_2b.weight.set_(torch.cat(single_unit_weights_2b, dim=0))

In [11]:
def f2b():
    
    stacked_input = input_tensor.repeat((1, new_groups_2b, 1, 1))
    stacked_output = stacked_conv_2b(stacked_input).repeat((1, out_channels, 1, 1))
    

    H = torch.zeros_like(conv.weight)

    s1 = stacked_output
    s2 = output_tensor[:, perm_indices_2b, ...]

    # Pearson correlation over the batch between each connection's
    # routed input (s1) and its corresponding output unit (s2).
    mu_in = s1.mean(dim=0)
    mu_out = s2.mean(dim=0)

    std_in = s1.std(dim=0)
    std_out = s2.std(dim=0)
    
    corr = ((s1 - mu_in) * (s2 - mu_out)).mean(dim=0) / (std_in * std_out)
    
    # Zero-out degenerate cases, then take the absolute value so strong
    # negative correlations also count as strong connections.
    corr[torch.where((std_in == 0) | (std_out == 0))] = 0
    corr = corr.abs()
    
    # Average over spatial dimensions to get one strength per connection.
    H_ = torch.mean(corr, (1, 2))

    H[connection_indxs_2b] = H_
    
    return H

Brute Force Method

Computationally speaking, this is the same method as the previous two (the coactivation count). Only now, instead of using conv layers to assist in the computations, we use for loops to brute-force our way through.

This serves mainly as a sanity check on the first two, to validate their outputs.
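
For reference, the receptive-field arithmetic assumed below (dilation=1): output unit (n, m) reads input rows n*stride - padding + [0, k1) and columns m*stride - padding + [0, k2). A tiny illustrative check against the layer above:

n, m = 1, 1
rows = [n * stride[0] - padding[0] + k for k in range(kernel_size[0])]
cols = [m * stride[1] - padding[1] + k for k in range(kernel_size[1])]
print(rows, cols)  # [1, 2] [1, 2] with stride (1, 1), padding (0, 0), kernel (2, 2)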


In [12]:
def coactivation(t1, t2):
    s = (torch.sigmoid(t1) > 0.5) * (torch.sigmoid(t2) > 0.5)
    return s 
    
def get_indeces_of_input_and_filter(n, m):
    """
    Assumes dilation=1, i.e. typical conv.
    """
    
    k1, k2 = kernel_size
    p1, p2 = padding
    s1, s2 = stride
    
    # Top-left corner of the receptive field for output unit (n, m):
    # input index = output index * stride - padding.
    i1 = n * s1 - p1
    i2 = m * s2 - p2
    
    indxs = []
    for c_in in range(in_channels):
        for n_k1 in range(k1):
            for m_k2 in range(k2):
                filter_indx = (c_in,      n_k1,      m_k2)
                input_indx  = (c_in, i1 + n_k1, i2 + m_k2)
                indxs.append((input_indx, filter_indx))
                
    return indxs

In [18]:
B     = output_tensor.shape[0]
N_out = output_tensor.shape[2]
M_out = output_tensor.shape[3]
C_in  = conv.weight.shape[1]
C_out = conv.weight.shape[0]

def f3():
    H = torch.zeros_like(conv.weight)
    for b in range(B):
        for c_out in range(C_out):
            for n_out in range(N_out):
                for m_out in range(M_out):
                    unit_1 = output_tensor[b, c_out, n_out, m_out]
                    indxs  = get_indeces_of_input_and_filter(n_out, m_out)

                    for input_indx, filter_indx in indxs:
                        c_in, n_in, m_in = input_indx
                        c_fl, n_fl, m_fl = filter_indx
                        unit_2 = input_tensor[b, c_in, n_in, m_in]

                        if coactivation(unit_1, unit_2):
                            H[c_out, c_fl, n_fl, m_fl] += 1
                            
    return H

Validation / Time Trials

Quick test to make sure they all give the same output, and to see how long they take to run.


In [19]:
assert f1().allclose(f2(), rtol=0, atol=0) and f2().allclose(f2b(), rtol=0, atol=0) and f2b().allclose(f3(), rtol=0, atol=0) 
%timeit f1()
%timeit f2()
%timeit f2b()
# %timeit f3()


9.64 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
326 µs ± 6.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
348 µs ± 5.33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Sparse Conv Layer

Now to implement a conv layer that utilizes the second implementation above.


In [14]:
class DSConv2d(torch.nn.Conv2d):
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.learning_iterations = 0
        
        self.activity_threshold = 0.5
        self.k1                 = max(int(0.1 * np.prod(self.weight.shape[2:])), 1)
        self.k2                 = max(int(0.15 * np.prod(self.weight.shape[2:])), 1)
        self.prune_dims         = [0, 1] # Todo: sort
        
        self.connections_tensor = torch.zeros_like(self.weight)
        self.prune_mask = torch.ones_like(self.weight)
        
        # Compute indices that loop over all connections of a channel.
        filter_indxs = list(itertools.product(*[range(d) for d in self.weight.shape[1:]]))

        # Compute indices that loop over all channels and filters.
        # This will be used to unpack the pointwise comparisons of the output.
        self.connection_indxs = []
        for idx in filter_indxs:
            i_ = list(idx)
            self.connection_indxs.extend([
                [c]+i_ for c in range(self.weight.shape[0])
            ])
        self.connection_indxs = list(zip(*self.connection_indxs))

        # Create new conv layer that groups its input and output.
        self.new_groups = len(filter_indxs)
        self.stacked_conv = torch.nn.Conv2d(
            in_channels=self.in_channels * self.new_groups,
            out_channels=self.out_channels * self.new_groups,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            padding_mode=self.padding_mode,
            dilation=self.dilation,
            groups=self.groups * self.new_groups,
            bias=False,
        )

        # Populate the weight matrix with stacked tensors having only one non-zero unit.
        single_unit_weights = [
            self.get_single_unit_weights(
                self.weight.shape,
                c, j, h,
            )
            for c, j, h in filter_indxs
        ]
        with torch.no_grad():
            self.stacked_conv.weight.set_(torch.cat(single_unit_weights, dim=0))    
    
    def get_single_unit_weights(self, shape, c, j, h):
        """
        Constructs and returns conv layer with traingin diabled and
        all zero weights except along the output channels for unit
        specified as (c, j, h).
        """

        # Construct weight.
        weight = torch.zeros(self.weight.shape)

        # Set weights to zero except those specified.
        weight[:, c, j, h] = 1

        return weight
    
    def update_connections_tensor(self, input_tensor, output_tensor):
        
        with torch.no_grad():
            stacked_input = input_tensor.repeat((1, self.new_groups, 1, 1))
            stacked_output = self.stacked_conv(stacked_input)

            s1 = torch.sigmoid(stacked_output).gt_(0.5)
            s2 = torch.sigmoid(output_tensor).gt_(0.5).repeat((1, self.new_groups, 1, 1))
            H_ = torch.sum(s2.mul(s1), (0, 2, 3,))

            self.connections_tensor[self.connection_indxs] = H_
    
    def progress_connections(self):
        """
        Prunes and adds connections.
        """
        
        with torch.no_grad():
            
            # Get strengths of all connections.
            strengths = self.connections_tensor.numpy()
            shape = strengths.shape

            # Determine all combinations of prune dimensions
            all_dims = range(len(shape))
            prune_indxs = [range(shape[d]) if d in self.prune_dims else [slice(None)] for d in all_dims]
            prune_indxs = itertools.product(*prune_indxs)

            # Along all combinations of prune dimensions:
            #    - Keep the strongest k1 connections.
            #    - Reinitialize the trailing k2 - k1 connections.
            k1 = self.k1
            k2 = self.k2
            for idx in prune_indxs:

                # Get the k1'th largest strength.
                s = strengths[idx].flatten()
                v1 = np.partition(s, -k1)[-k1]  # i.e. s.kthvalue(len(s) - k1).value

                # Keep the top k1 connections - prune those below.
                c = self.weight[idx].flatten()
                prune_mask = s < v1
                c[torch.from_numpy(prune_mask)] = 0

                # Get the trailing k2 - k1 connections.
                v2 = np.partition(s, -k2)[-k2]  # i.e. s.kthvalue(len(s) - k2).value
                new_mask = (s > v2) & prune_mask

                # Reinitialize the trailing k2 - k1 connections.
                # Note: [None, :] is added as kaiming_uniform_ requires a 2d tensor,
                # and the masked values are copied out, so they must be written back.
                if new_mask.any():
                    new_mask_t = torch.from_numpy(new_mask)
                    new_weights = c[new_mask_t][None, :]
                    torch.nn.init.kaiming_uniform_(new_weights)
                    c[new_mask_t] = new_weights.squeeze(0)

                # Reshape connections and update the weight.
                self.weight[idx] = c.reshape(self.weight[idx].shape)

                self.prune_mask = prune_mask

            # Reset connection strengths.
            self.connections_tensor = torch.zeros_like(self.weight)
            
    def prune_randomly(self):
        
        with torch.no_grad():
            
            prune_mask = torch.rand(self.weight.shape) < 0.15  # prune 15% of weights
            self.weight[prune_mask] = 0
            
            # Reinitialize those that are zero.
            keep_mask = ~prune_mask
            new_mask  = (self.weight == 0) & keep_mask
            new_weights = self.weight[new_mask]
            if len(new_weights) > 0:
                torch.nn.init.kaiming_uniform_(new_weights[None, :])
                self.weight[new_mask] = new_weights
       
    def __call__(self, input_tensor, *args, **kwargs):
        output_tensor = super().__call__(input_tensor, *args, **kwargs)
        if self.learning_iterations % 20 == 0:
            self.update_connections_tensor(input_tensor, output_tensor)
        self.learning_iterations += 1
        return output_tensor
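
A minimal usage sketch (hypothetical shapes, separate from the training run below): connection strengths accumulate during forward passes, and pruning happens on demand:

layer = DSConv2d(3, 6, 5)
x = torch.randn(4, 3, 32, 32)
for _ in range(40):
    _ = layer(x)  # update_connections_tensor runs on every 20th call
layer.progress_connections()  # prune weakest connections, reinitialize a few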

Test Training a Network

The following is a simple toy example copied mostly from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

The main difference here is that the network utilizes the DSConv2d module defined above. This exercise mostly serves to gain confidence in the implementation with respect to its ability to run without errors - it is not concerned with verifying training improvements just yet.

Load Data


In [15]:
root_path = '~/nta/datasets'

In [21]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root=root_path, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=root_path, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified

Time Trials

Quick test to compare the runtime of a standard dense conv layer against the DSConv2d layer, which periodically updates its connections tensor.


In [22]:
cd = torch.nn.Conv2d(3, 6, 5)
cs = DSConv2d(3, 6, 5)

dataiter = iter(trainloader)
images, labels = next(dataiter)

print('Dense CNN forward pass:')
%timeit cd(images)
print('DSConv2d forward pass:')
%timeit cs(images)


Dense CNN forward pass:
215 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
DSConv2d forward pass:
3.59 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Network Setup


In [278]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv1 = DSConv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv2 = DSConv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    
net = Net()

Training


In [279]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
            
            net.conv1.progress_connections()
            net.conv2.progress_connections()
    
            break
        
    break
    
#     # Compare with pruning random weights.
#     net.conv1.prune_randomly()
#     net.conv2.prune_randomly() 

print('Finished Training')


[1,  2000] loss: 2.221
Finished Training

Testing Accuracy


In [17]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))


Accuracy of plane :  0 %
Accuracy of   car :  0 %
Accuracy of  bird : 79 %
Accuracy of   cat :  2 %
Accuracy of  deer : 10 %
Accuracy of   dog :  0 %
Accuracy of  frog :  1 %
Accuracy of horse :  0 %
Accuracy of  ship :  0 %
Accuracy of truck :  0 %