In [1]:
# general imports
import os
import numpy as np

# torch imports
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as schedulers
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary

# nupic research imports
from nupic.research.frameworks.pytorch.image_transforms import RandomNoise
from nupic.torch.modules import KWinners

PATH_TO_WHERE_DATASET_WILL_BE_SAVED = PATH = "~/nta/datasets"

1. Load dataset


In [2]:
class Dataset:
    """Loads a dataset.
    Returns object with a pytorch train and test loader
    """

    def __init__(self, config=None):

        defaults = dict(
            dataset_name='MNIST',
            data_dir=os.path.expanduser(PATH),
            batch_size_train=128,
            batch_size_test=128,
            stats_mean=None,
            stats_std=None,
            augment_images=False,
            test_noise=False,
            noise_level=0.1,
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        
        # recover mean and std to normalize dataset
        if not self.stats_mean or not self.stats_std:
            tempset = getattr(datasets, self.dataset_name)(
                root=self.data_dir, train=True, transform=transforms.ToTensor()
            )
            self.stats_mean = (tempset.data.float().mean().item() / 255,)
            self.stats_std = (tempset.data.float().std().item() / 255,)
            del tempset

        # set up transformations
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.stats_mean, self.stats_std),
            ]
        )
        
        # set up augment transforms for training
        if not self.augment_images:
            aug_transform = transform
        else:
            aug_transform = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(self.stats_mean, self.stats_std),
                ]
            )

        # load train set
        train_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=True, transform=aug_transform, download=True
        )
        self.train_loader = DataLoader(
            dataset=train_set, batch_size=self.batch_size_train, shuffle=True
        )

        # load test set
        test_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=False, transform=transform, download=True
        )
        self.test_loader = DataLoader(
            dataset=test_set, batch_size=self.batch_size_test, shuffle=False
        )

        # noise dataset
        noise = self.noise_level
        noise_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.stats_mean, self.stats_std),
                RandomNoise(
                    noise, high_value=0.5 + 2 * 0.20, low_value=0.5 - 2 * 0.2
                ),
            ]
        )
        noise_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=False, transform=noise_transform
        )
        self.noise_loader = DataLoader(
            dataset=noise_set, batch_size=self.batch_size_test, shuffle=False
        )

In [3]:
dataset = Dataset()

2. Build Network


In [4]:
class MLP(nn.Module):
    """Simple 3 hidden layers + output MLP"""

    def __init__(self, config=None):
        super(MLP, self).__init__()

        defaults = dict(
            device='cpu',
            input_size=784,
            num_classes=10,
            hidden_sizes=[400, 400, 400],
            batch_norm=False,
            dropout=False,
            use_kwinners=False
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        self.device = torch.device(self.device)
        
        # decide which actiovation function to use
        if self.use_kwinners: self.activation_func = self._kwinners
        else: self.activation_func = lambda _: nn.ReLU()

        # create the layers
        layers = [
            *self._linear_block(self.input_size, self.hidden_sizes[0]),
            *self._linear_block(self.hidden_sizes[0],  self.hidden_sizes[1]),
            *self._linear_block(self.hidden_sizes[1], self.hidden_sizes[2]),
            nn.Linear(self.hidden_sizes[2], self.num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def _linear_block(self, a, b):
        block = [nn.Linear(a, b), self.activation_func(b)]
        if self.batch_norm: block.append(nn.BatchNorm1d(b))
        if self.dropout: block.append(nn.Dropout(p=self.dropout))
        return block

    def _kwinners(self, num_units):
        return KWinners(
            n=num_units,
            percent_on=0.3,
            boost_strength=1.4,
            boost_strength_factor=0.7,
        )
    
    def forward(self, x):
        # need to flatten input before forward pass
        return self.classifier(x.view(-1, self.input_size))
    
    def alternative_forward(self, x):
        """Replace forward function by this to visualize activations"""
        # need to flatten before forward pass
        x = x.view(-1, self.input_size)
        for layer in self.classifier:
            # apply the transformation
            x = layer(x)
            # do something with the activation
            print(torch.mean(x).item())
        
        return x

In [5]:
network = MLP()

In [6]:
summary(network, input_size=(1,28,28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 400]         314,000
              ReLU-2                  [-1, 400]               0
            Linear-3                  [-1, 400]         160,400
              ReLU-4                  [-1, 400]               0
            Linear-5                  [-1, 400]         160,400
              ReLU-6                  [-1, 400]               0
            Linear-7                   [-1, 10]           4,010
================================================================
Total params: 638,810
Trainable params: 638,810
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 2.44
Estimated Total Size (MB): 2.46
----------------------------------------------------------------

3. Build Model


In [7]:
class BaseModel:
    """Base model, with training loops and logging functions."""

    def __init__(self, network, dataset, config=None):
        defaults = dict(
            learning_rate=0.1,
            momentum=0.9,
            device="cpu",
            lr_scheduler=False,
            sparsity=0,
            weight_decay=1e-4,
            test_noise=False
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)

        self.device = torch.device(self.device)
        self.network = network.to(self.device)
        self.dataset = dataset
        self.setup()
    
        # apply sparsity
        if self.sparsity:
            self.sparse_layers = []
            for layer in self.network.modules():
                if isinstance(layer, nn.Linear):
                    shape = layer.weight.shape
                    mask = (torch.rand(shape) > self.sparsity).float().to(self.device)
                    layer.weight.data *= mask
                    self.sparse_layers.append((layer, mask))

    def setup(self):
        self.optimizer = optim.SGD(
            self.network.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

        # add a learning rate scheduler
        if self.lr_scheduler:
            milestones = [int(self.num_epochs/2), int(num_epochs*3/4)]
            self.lr_scheduler = schedulers.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1
            )

        # init loss function
        self.loss_func = nn.CrossEntropyLoss()
        self.epoch = 0
    
    def train(self, num_epochs, test_noise=False):
        for i in range(num_epochs):
            log = self.run_epoch(test_noise)
            # print acc
            if test_noise:
                print("Train acc: {:.4f}, Val acc: {:.4f}, Noise acc: {:.4f}".format(
                    log['train_acc'], log['val_acc'], log['test_acc']))
            else:
                print("Train acc: {:.4f}, Val acc: {:.4f}".format(
                    log['train_acc'], log['val_acc']))
        
    def run_epoch(self, test_noise):

        log = {}        
        self.epoch += 1
        log['current_epoch'] = self.epoch
        
        # train
        self.network.train()
        log['train_loss'], log['train_acc'] = \
            self._run_one_pass(self.dataset.train_loader, train=True)
        # validate
        self.network.eval()
        log['val_loss'], log['val_acc'] = \
            self._run_one_pass(self.dataset.test_loader, train=False)
        # additional validation for noise
        if test_noise:
            log['test_loss'], log['test_acc'] = \
                self._run_one_pass(self.dataset.noise_loader, train=False, noise=True)

        # any updates post training, e.g. scheduler
        self._post_epoch_updates(dataset)
        
        return log

    def _post_epoch_updates(self, dataset=None):
        # update learning rate
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _run_one_pass(self, loader, train=True, noise=False):

        epoch_loss = 0
        correct = 0
        for inputs, targets in loader:
            # setup for training
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            # training loop
            with torch.set_grad_enabled(train):
                # forward + backward + optimize
                outputs = self.network(inputs)
                _, preds = torch.max(outputs, 1)
                correct += torch.sum(targets == preds).item()
                loss = self.loss_func(outputs, targets)
                if train:
                    loss.backward()
                    self.optimizer.step()
                    # if sparse, apply the mask to weights after optimization
                    if self.sparsity:
                        for layer, mask in self.sparse_layers:
                            layer.weight.data *= mask

            # keep track of loss
            epoch_loss += loss.item() * inputs.size(0)

        # store loss and acc at each pass
        loss = epoch_loss / len(loader.dataset)
        acc = correct / len(loader.dataset)

        return loss, acc

In [8]:
model = BaseModel(network, dataset, dict(k_winners=False))

4. Train regular dense network with ReLU


In [9]:
model.train(num_epochs=5, test_noise=True)


Train acc: 0.9137, Val acc: 0.9685, Noise acc: 0.9643
Train acc: 0.9705, Val acc: 0.9721, Noise acc: 0.9678
Train acc: 0.9780, Val acc: 0.9727, Noise acc: 0.9674
Train acc: 0.9830, Val acc: 0.9736, Noise acc: 0.9708
Train acc: 0.9856, Val acc: 0.9733, Noise acc: 0.9694

5. KWinnners instead of ReLU


In [10]:
# rebuild the network with KWinners as the activation function
network = MLP(dict(kwinners=True))
# build model
model = BaseModel(network, dataset, dict(k_winners=True))
# run model
model.train(num_epochs=5, test_noise=True)


Train acc: 0.9096, Val acc: 0.9626, Noise acc: 0.9600
Train acc: 0.9673, Val acc: 0.9682, Noise acc: 0.9658
Train acc: 0.9780, Val acc: 0.9753, Noise acc: 0.9724
Train acc: 0.9820, Val acc: 0.9752, Noise acc: 0.9729
Train acc: 0.9859, Val acc: 0.9792, Noise acc: 0.9774

6. KWinners + Sparse


In [11]:
# rebuild the network with KWinners as the activation function
network = MLP(dict(kwinners=True))
# build model
model = BaseModel(network, dataset, 
                  dict(k_winners=True, sparsity=0.8))
# run model
model.train(num_epochs=5, test_noise=True)


Train acc: 0.7464, Val acc: 0.9506, Noise acc: 0.9453
Train acc: 0.9638, Val acc: 0.9687, Noise acc: 0.9620
Train acc: 0.9754, Val acc: 0.9745, Noise acc: 0.9669
Train acc: 0.9819, Val acc: 0.9771, Noise acc: 0.9720
Train acc: 0.9858, Val acc: 0.9748, Noise acc: 0.9643

7. Inspecting

  • Inspect weights after training

In [12]:
for layer in network.modules():
    if isinstance(layer, nn.Linear):
        print(torch.mean(layer.weight).item(), torch.std(layer.weight).item())


-0.00037051260005682707 0.031048087403178215
0.00023493298795074224 0.029837530106306076
0.0005359541974030435 0.023514844477176666
0.0012203083606436849 0.11726776510477066
  • Inspect activations during the forward pass (can use it to inspect during training as well)

In [13]:
network.forward = network.alternative_forward
network(torch.rand(1,1,28,28));


-0.18638744950294495
0.17727136611938477
0.017351068556308746
0.10145770758390427
0.05837811529636383
0.11567458510398865
-1.3849645853042603

In [ ]: