In [1]:
    
# general imports
import os
import numpy as np
# torch imports
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as schedulers
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
# nupic research imports
from nupic.research.frameworks.pytorch.image_transforms import RandomNoise
from nupic.torch.modules import KWinners
PATH_TO_WHERE_DATASET_WILL_BE_SAVED = PATH = "~/nta/datasets"
    
In [2]:
    
class Dataset:
    """Loads a dataset.
    Returns object with a pytorch train and test loader
    """
    def __init__(self, config=None):
        defaults = dict(
            dataset_name='MNIST',
            data_dir=os.path.expanduser(PATH),
            batch_size_train=128,
            batch_size_test=128,
            stats_mean=None,
            stats_std=None,
            augment_images=False,
            test_noise=False,
            noise_level=0.1,
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        
        # compute mean and std to normalize the dataset, unless given in config
        if not self.stats_mean or not self.stats_std:
            tempset = getattr(datasets, self.dataset_name)(
                root=self.data_dir, train=True, transform=transforms.ToTensor()
            )
            self.stats_mean = (tempset.data.float().mean().item() / 255,)
            self.stats_std = (tempset.data.float().std().item() / 255,)
            del tempset
        # set up transformations
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.stats_mean, self.stats_std),
            ]
        )
        
        # set up augmentation transforms for training
        # note: the crop/flip below are CIFAR-style and assume 32x32 inputs;
        # they are off by default (augment_images=False) for 28x28 MNIST
        if not self.augment_images:
            aug_transform = transform
        else:
            aug_transform = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(self.stats_mean, self.stats_std),
                ]
            )
        # load train set
        train_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=True, transform=aug_transform, download=True
        )
        self.train_loader = DataLoader(
            dataset=train_set, batch_size=self.batch_size_train, shuffle=True
        )
        # load test set
        test_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=False, transform=transform, download=True
        )
        self.test_loader = DataLoader(
            dataset=test_set, batch_size=self.batch_size_test, shuffle=False
        )
        # noise dataset
        noise = self.noise_level
        noise_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.stats_mean, self.stats_std),
                RandomNoise(
                    noise, high_value=0.5 + 2 * 0.2, low_value=0.5 - 2 * 0.2
                ),
            ]
        )
        noise_set = getattr(datasets, self.dataset_name)(
            root=self.data_dir, train=False, transform=noise_transform
        )
        self.noise_loader = DataLoader(
            dataset=noise_set, batch_size=self.batch_size_test, shuffle=False
        )
    
In [3]:
    
dataset = Dataset()
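    
A quick sanity check, not part of the original run: pull one batch from the train loader and confirm the normalization stats the constructor computed.

In [ ]:
    
# sketch: inspect the Dataset object built above
print("train mean/std:", dataset.stats_mean, dataset.stats_std)
images, labels = next(iter(dataset.train_loader))
print("batch shape:", images.shape)  # expect torch.Size([128, 1, 28, 28]) for MNIST
print("label shape:", labels.shape)  # expect torch.Size([128])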
    
In [4]:
    
class MLP(nn.Module):
    """Simple 3 hidden layers + output MLP"""
    def __init__(self, config=None):
        super(MLP, self).__init__()
        defaults = dict(
            device='cpu',
            input_size=784,
            num_classes=10,
            hidden_sizes=[400, 400, 400],
            batch_norm=False,
            dropout=False,
            use_kwinners=False
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        self.device = torch.device(self.device)
        
        # decide which activation function to use
        if self.use_kwinners:
            self.activation_func = self._kwinners
        else:
            self.activation_func = lambda _: nn.ReLU()
        # create the layers
        layers = [
            *self._linear_block(self.input_size, self.hidden_sizes[0]),
            *self._linear_block(self.hidden_sizes[0], self.hidden_sizes[1]),
            *self._linear_block(self.hidden_sizes[1], self.hidden_sizes[2]),
            nn.Linear(self.hidden_sizes[2], self.num_classes),
        ]
        self.classifier = nn.Sequential(*layers)
    def _linear_block(self, a, b):
        block = [nn.Linear(a, b), self.activation_func(b)]
        if self.batch_norm: block.append(nn.BatchNorm1d(b))
        if self.dropout: block.append(nn.Dropout(p=self.dropout))
        return block
    def _kwinners(self, num_units):
        return KWinners(
            n=num_units,
            percent_on=0.3,
            boost_strength=1.4,
            boost_strength_factor=0.7,
        )
    
    def forward(self, x):
        # need to flatten input before forward pass
        return self.classifier(x.view(-1, self.input_size))
    
    def alternative_forward(self, x):
        """Replace forward function by this to visualize activations"""
        # need to flatten before forward pass
        x = x.view(-1, self.input_size)
        for layer in self.classifier:
            # apply the transformation
            x = layer(x)
            # do something with the activation
            print(torch.mean(x).item())
        
        return x
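    
As an alternative to swapping out forward, PyTorch's standard forward-hook API can observe activations without modifying the model; a minimal sketch (the helper names below are ours):

In [ ]:
    
# sketch: print the mean activation of each layer via forward hooks
def print_mean_activation(module, inputs, output):
    # forward hooks receive (module, inputs, output)
    print(type(module).__name__, torch.mean(output).item())

net = MLP()
handles = [layer.register_forward_hook(print_mean_activation)
           for layer in net.classifier]
net(torch.rand(1, 1, 28, 28))
# remove the hooks so later passes run silently
for handle in handles:
    handle.remove()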
    
In [5]:
    
network = MLP()
    
In [6]:
    
summary(network, input_size=(1, 28, 28))
    
    
In [7]:
    
class BaseModel:
    """Base model, with training loops and logging functions."""
    def __init__(self, network, dataset, config=None):
        defaults = dict(
            learning_rate=0.1,
            momentum=0.9,
            device="cpu",
            lr_scheduler=False,
            num_epochs=10,  # planned training length; places the LR milestones
            sparsity=0,
            weight_decay=1e-4,
            test_noise=False,
        )
        defaults.update(config or {})
        self.__dict__.update(defaults)
        self.device = torch.device(self.device)
        self.network = network.to(self.device)
        self.dataset = dataset
        self.setup()
    
        # apply sparsity
        if self.sparsity:
            self.sparse_layers = []
            for layer in self.network.modules():
                if isinstance(layer, nn.Linear):
                    shape = layer.weight.shape
                    mask = (torch.rand(shape) > self.sparsity).float().to(self.device)
                    layer.weight.data *= mask
                    self.sparse_layers.append((layer, mask))
    def setup(self):
        self.optimizer = optim.SGD(
            self.network.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        # add a learning rate scheduler
        if self.lr_scheduler:
            milestones = [int(self.num_epochs / 2), int(self.num_epochs * 3 / 4)]
            self.lr_scheduler = schedulers.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1
            )
        # init loss function
        self.loss_func = nn.CrossEntropyLoss()
        self.epoch = 0
    
    def train(self, num_epochs, test_noise=False):
        for i in range(num_epochs):
            log = self.run_epoch(test_noise)
            # print acc
            if test_noise:
                print("Train acc: {:.4f}, Val acc: {:.4f}, Noise acc: {:.4f}".format(
                    log['train_acc'], log['val_acc'], log['test_acc']))
            else:
                print("Train acc: {:.4f}, Val acc: {:.4f}".format(
                    log['train_acc'], log['val_acc']))
        
    def run_epoch(self, test_noise):
        log = {}        
        self.epoch += 1
        log['current_epoch'] = self.epoch
        
        # train
        self.network.train()
        log['train_loss'], log['train_acc'] = \
            self._run_one_pass(self.dataset.train_loader, train=True)
        # validate
        self.network.eval()
        log['val_loss'], log['val_acc'] = \
            self._run_one_pass(self.dataset.test_loader, train=False)
        # additional validation on the noisy test set
        if test_noise:
            log['test_loss'], log['test_acc'] = \
                self._run_one_pass(self.dataset.noise_loader, train=False)
        # any post-epoch updates, e.g. the LR scheduler
        self._post_epoch_updates()
        
        return log
    def _post_epoch_updates(self):
        # update learning rate
        if self.lr_scheduler:
            self.lr_scheduler.step()
    def _run_one_pass(self, loader, train=True):
        epoch_loss = 0
        correct = 0
        for inputs, targets in loader:
            # setup for training
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            # training loop
            with torch.set_grad_enabled(train):
                # forward + backward + optimize
                outputs = self.network(inputs)
                _, preds = torch.max(outputs, 1)
                correct += torch.sum(targets == preds).item()
                loss = self.loss_func(outputs, targets)
                if train:
                    loss.backward()
                    self.optimizer.step()
                    # if sparse, apply the mask to weights after optimization
                    if self.sparsity:
                        for layer, mask in self.sparse_layers:
                            layer.weight.data *= mask
            # keep track of loss
            epoch_loss += loss.item() * inputs.size(0)
        # store loss and acc at each pass
        loss = epoch_loss / len(loader.dataset)
        acc = correct / len(loader.dataset)
        return loss, acc
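    
To actually use the MultiStepLR schedule, the config needs the planned training length so the milestones land at 1/2 and 3/4 of the run; a sketch with illustrative values:

In [ ]:
    
# sketch: enable the LR scheduler (values are illustrative)
scheduled_model = BaseModel(
    MLP(), dataset,
    dict(lr_scheduler=True, num_epochs=20, learning_rate=0.1)
)
# scheduled_model.train(num_epochs=20)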
    
In [8]:
    
model = BaseModel(network, dataset)
    
In [9]:
    
model.train(num_epochs=5, test_noise=True)
    
    
In [10]:
    
# rebuild the network with KWinners as the activation function
network = MLP(dict(use_kwinners=True))
# build model
model = BaseModel(network, dataset)
# run model
model.train(num_epochs=5, test_noise=True)
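    
One detail the loop above omits: nupic.torch expects the KWinners boost strength to decay once per epoch. A sketch of wiring that in, assuming the update_boost_strength helper is exported from nupic.torch.modules as in the library's examples:

In [ ]:
    
# sketch: decay KWinners boosting after each epoch, per nupic.torch convention
from nupic.torch.modules import update_boost_strength
model.network.apply(update_boost_strength)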
    
    
In [11]:
    
# rebuild the network with KWinners as the activation function
network = MLP(dict(use_kwinners=True))
# build model with 80% of each linear layer's weights masked to zero
model = BaseModel(network, dataset, dict(sparsity=0.8))
# run model
model.train(num_epochs=5, test_noise=True)
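    
To probe noise robustness beyond the single configured noise_level, one can rebuild the noise loader at several levels and reuse the evaluation pass; a sketch that reuses the private _run_one_pass helper:

In [ ]:
    
# sketch: sweep noise levels and evaluate the trained sparse model on each
model.network.eval()
for level in (0.05, 0.1, 0.2, 0.4):
    noisy = Dataset(dict(noise_level=level))
    _, acc = model._run_one_pass(noisy.noise_loader, train=False)
    print("noise {:.2f}: acc {:.4f}".format(level, acc))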
    
    
In [12]:
    
# inspect the mean and std of the weights in each linear layer
for layer in network.modules():
    if isinstance(layer, nn.Linear):
        print(torch.mean(layer.weight).item(), torch.std(layer.weight).item())
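    
Since the masks are re-applied after every optimizer step, the fraction of exactly-zero weights in each linear layer should stay close to the configured sparsity; a quick sketch to confirm:

In [ ]:
    
# sketch: check that roughly 80% of each linear layer's weights stayed zero
for layer in network.modules():
    if isinstance(layer, nn.Linear):
        zero_frac = (layer.weight == 0).float().mean().item()
        print("fraction of zero weights: {:.3f}".format(zero_frac))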
    
    
In [13]:
    
# swap in the verbose forward pass to print the mean activation of each layer
network.forward = network.alternative_forward
network(torch.rand(1, 1, 28, 28));
    
    
In [ ]: