In [1]:
import math
import torch
from torch import nn
from torchvision import models
from nupic.torch.modules import KWinners, KWinners2d
In [2]:
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


def Flatten():
    return Lambda(lambda x: x.view((x.size(0), -1)))
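A quick sanity check of the helper (added here for illustration, not part of the original notebook): Flatten should keep the batch dimension and collapse everything else, which is what the linear block in the next cell relies on.

In [ ]:
# (8, 64, 5, 5) -> (8, 1600); 1600 == 25 * 64, the input size of the linear block
dummy = torch.randn(8, 64, 5, 5)
Flatten()(dummy).shape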
In [3]:
class GSCHeb(nn.Module):
    """
    Simple sparse network for GSC with three hidden blocks (two convolutional,
    one linear) plus an output layer, similar to the MLPHeb used in the SET paper.
    """

    def __init__(self, config=None):
        super(GSCHeb, self).__init__()
        defaults = dict(
            device='cpu',
            input_size=1024,
            num_classes=12,
            hebbian_learning=True,
            boost_strength=1.5,
            boost_strength_factor=0.9,
            k_inference_factor=1.5,
            duty_cycle_period=1000
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        self.device = torch.device(self.device)

        # hidden layers
        layers = [
            *self._conv_block(1, 64, percent_on=0.095),   # 32x32 -> conv 28x28 -> pool 14x14
            *self._conv_block(64, 64, percent_on=0.125),  # 14x14 -> conv 10x10 -> pool 5x5
            Flatten(),
            *self._linear_block(25 * 64, 1000, percent_on=0.1),
        ]
        # output layer
        layers.append(nn.Linear(1000, self.num_classes))

        # keep a flat list of the layers (redundant with the Sequential)
        # so that forward() can traverse them one by one
        self.layers = layers
        self.classifier = nn.Sequential(*layers)

        # track correlations
        self.correlations = []

    def _conv_block(self, fin, fout, percent_on=0.1):
        block = [
            nn.Conv2d(fin, fout, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(fout, affine=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            self._kwinners(fout, percent_on),
        ]
        return block

    def _linear_block(self, fin, fout, percent_on=0.1):
        block = [
            nn.Linear(fin, fout),
            nn.BatchNorm1d(fout, affine=False),
            self._kwinners(fout, percent_on, twod=False),
        ]
        return block

    def _kwinners(self, fout, percent_on, twod=True):
        if twod:
            activation_func = KWinners2d
        else:
            activation_func = KWinners
        return activation_func(
            fout,
            percent_on=percent_on,
            boost_strength=self.boost_strength,
            boost_strength_factor=self.boost_strength_factor,
            k_inference_factor=self.k_inference_factor,
            duty_cycle_period=self.duty_cycle_period
        )

    def _has_activation(self, idx, layer):
        # correlations are tracked after the linear KWinners activation and after
        # the output layer (KWinners2d layers do not match this isinstance check)
        return (
            idx == len(self.layers) - 1
            or isinstance(layer, KWinners)
        )

    def forward(self, x):
        """A faster and approximate way to track correlations"""
        # x = x.view(-1, self.input_size)  # resize if needed, e.g. for MNIST
        prev_act = (x > 0).detach().float()
        idx_activation = 0
        for idx_layer, layer in enumerate(self.layers):
            # do the forward calculation normally
            x = layer(x)
            if self.hebbian_learning:
                n_samples = x.shape[0]
                if self._has_activation(idx_layer, layer):
                    with torch.no_grad():
                        curr_act = (x > 0).detach().float()
                        # add the outer product of consecutive binary activations
                        # to the correlations, per sample (flattened, since the
                        # previous activation may still be spatial, e.g. the input)
                        for s in range(n_samples):
                            outer = torch.ger(prev_act[s].flatten(),
                                              curr_act[s].flatten())
                            if idx_activation + 1 > len(self.correlations):
                                self.correlations.append(outer)
                            else:
                                self.correlations[idx_activation] += outer
                        # the current activation becomes the previous one
                        # for the next tracked layer
                        prev_act = curr_act
                        # move to next activation
                        idx_activation += 1
        return x
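The Hebbian bookkeeping in forward() reduces to one operation: for each sample, the outer product of two binary activation vectors, whose entry (i, j) is 1 exactly when unit i of the previous layer and unit j of the current layer are active together. The toy cell below (an added illustration, not from the original notebook) shows a single such update.

In [ ]:
# Toy illustration of the co-activation update used in forward():
# rows of the outer product corresponding to active units in `prev`
# pick up the activity pattern of `curr`; all other rows stay zero.
prev = torch.tensor([1., 0., 1.])      # 3 units in the previous layer
curr = torch.tensor([0., 1., 1., 0.])  # 4 units in the current layer
torch.ger(prev, curr)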
In [4]:
from torchsummary import summary
In [5]:
model = GSCHeb()
summary(model, (1, 32, 32), device="cpu")  # the model's parameters stay on the CPU
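To see the Hebbian tracking in action, the sketch below (added for illustration; random noise stands in for GSC audio features) runs a small batch through a fresh model and prints the shape of each accumulated correlation matrix: one between the flattened input and the hidden linear activation, and one between that activation and the output layer.

In [ ]:
# run a random batch and inspect the accumulated correlation matrices
heb_model = GSCHeb()
dummy_batch = torch.randn(4, 1, 32, 32)
_ = heb_model(dummy_batch)
for i, corr in enumerate(heb_model.correlations):
    print(i, tuple(corr.shape))
# expected: 0 (1024, 1000) and 1 (1000, 12)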
In [ ]: