In [1]:
import math

import torch
from torch import nn
from torchvision import models
from nupic.torch.modules import KWinners, KWinners2d

In [2]:
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
    
    def forward(self, x): 
        return self.func(x)

def Flatten():
    return Lambda(lambda x: x.view((x.size(0), -1)))
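
For reference, Flatten only collapses everything after the batch dimension into one vector per sample; a quick sanity check (shapes chosen arbitrarily):

x = torch.randn(8, 64, 5, 5)     # (batch, channels, height, width)
print(Flatten()(x).shape)        # torch.Size([8, 1600]) -- 64 * 5 * 5 flattened per sample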

In [3]:
class GSCHeb(nn.Module):
    """
    Simple network with three hidden layers plus an output layer (MLPHeb-style),
    similar to the one used in the SET paper.
    """

    def __init__(self, config=None):
        super(GSCHeb, self).__init__()

        defaults = dict(
            device='cpu',
            input_size=1024,
            num_classes=12,
            hebbian_learning=True,
            boost_strength=1.5,
            boost_strength_factor=0.9,
            k_inference_factor=1.5,
            duty_cycle_period=1000
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        self.device = torch.device(self.device)

        # hidden layers
        layers = [
            *self._conv_block(1, 64, percent_on=0.095),   # 32x32 -> 28x28 (conv) -> 14x14 (pool)
            *self._conv_block(64, 64, percent_on=0.125),  # 14x14 -> 10x10 (conv) -> 5x5 (pool)
            Flatten(),
            *self._linear_block(25*64, 1000, percent_on=0.1),
        ]
        # output layer
        layers.append(nn.Linear(1000, self.num_classes))

        # keep the plain list of layers (redundant with the Sequential) so forward() can traverse them one by one
        self.layers = layers
        self.classifier = nn.Sequential(*layers)

        # track correlations
        self.correlations = []

    def _conv_block(self, fin, fout, percent_on=0.1):
        block = [
            nn.Conv2d(fin, fout, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(fout, affine=False),
            nn.MaxPool2d(kernel_size=2, stride=2),       
            self._kwinners(fout, percent_on),
        ]
        return block

    def _linear_block(self, fin, fout, percent_on=0.1):
        block = [
            nn.Linear(fin, fout), 
            nn.BatchNorm1d(fout, affine=False),
            self._kwinners(fout, percent_on, twod=False),
        ]
        return block

    def _kwinners(self, fout, percent_on, twod=True):
        if twod:
            activation_func = KWinners2d
        else:
            activation_func = KWinners
        return activation_func(
            fout,
            percent_on=percent_on,
            boost_strength=self.boost_strength,
            boost_strength_factor=self.boost_strength_factor,
            k_inference_factor=self.k_inference_factor,
            duty_cycle_period=self.duty_cycle_period
        )

    def _has_activation(self, idx, layer):
        return (
            idx == len(self.layers) - 1
            or isinstance(layer, KWinners)
        )

    def forward(self, x):
        """A faster and approximate way to track correlations"""
        # x = x.view(-1, self.input_size)  # resiaze if needed, eg mnist
        prev_act = (x > 0).detach().float()
        idx_activation = 0
        for idx_layer, layer in enumerate(self.layers):
            # do the forward calculation normally
            x = layer(x)
            if self.hebbian_learning:
                n_samples = x.shape[0]
                if self._has_activation(idx_layer, layer):
                    with torch.no_grad():
                        curr_act = (x > 0).detach().float()
                        # add outer product to the correlations, per sample
                        for s in range(n_samples):
                            outer = torch.ger(prev_act[s], curr_act[s])
                            if idx_activation + 1 > len(self.correlations):
                                self.correlations.append(outer)
                            else:
                                self.correlations[idx_activation] += outer
                        # the current activation becomes the previous one for the next tracked layer
                        prev_act = curr_act
                        # move to next activation
                        idx_activation += 1

        return x
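
The correlation tracking in forward() approximates Hebbian co-activation: for each pair of consecutive tracked activations it accumulates, per sample, the outer product of their binarized (active = 1, inactive = 0) values. A minimal sketch of a single update for two 1D activations (toy sizes, purely illustrative):

prev_act = (torch.randn(4) > 0).float()   # 4 "pre-synaptic" units, binarized
curr_act = (torch.randn(3) > 0).float()   # 3 "post-synaptic" units, binarized
corr = torch.ger(prev_act, curr_act)      # 4x3 outer product: 1 where both units were active
# summing these per-sample outer products over a batch counts, for each (i, j),
# how often unit i and unit j fired together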

In [4]:
from torchsummary import summary

In [5]:
model = GSCHeb()
summary(model, (1, 32, 32))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-50977db11cdd> in <module>
      1 model = GSCHeb()
----> 2 summary(model, (1, 32, 32))

~/miniconda3/envs/numenta/lib/python3.7/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
     70     # make a forward pass
     71     # print(x.shape)
---> 72     model(*x)
     73 
     74     # remove these hooks

~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

<ipython-input-3-de48f15e580d> in forward(self, x)
     91                         # add outer product to the correlations, per sample
     92                         for s in range(n_samples):
---> 93                             outer = torch.ger(prev_act[s], curr_act[s])
     94                             if idx_activation + 1 > len(self.correlations):
     95                                 self.correlations.append(outer)

RuntimeError: vector and vector expected, got 3D, 1D tensors at ../aten/src/TH/generic/THTensorMath.cpp:886
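
torch.ger expects two 1D vectors, but here prev_act[s] is still the raw per-sample input (1x32x32, i.e. 3D) while curr_act[s] is the 1D output of the linear block's KWinners. That suggests the KWinners2d conv activations never match the isinstance(layer, KWinners) check in _has_activation (it would need something like isinstance(layer, (KWinners, KWinners2d)) to catch them), so the first tracked pair combines a 3D tensor with a 1D one. A minimal sketch of the shape mismatch and one way around it, flattening per-sample activations before the outer product (a sketch only; tracking conv activations this way also produces very large correlation matrices):

prev_act = (torch.randn(1, 32, 32) > 0).float()   # 3D per-sample input, as in the traceback
curr_act = (torch.randn(1000) > 0).float()        # 1D k-winners output of the linear block
# torch.ger(prev_act, curr_act) reproduces the RuntimeError above; flattening first works:
outer = torch.ger(prev_act.flatten(), curr_act.flatten())   # (1024, 1000) co-activation matrix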

In [ ]: