In [1]:
# general imports
import os
import numpy as np
# torch imports
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as schedulers
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
# nupic research imports
from nupic.research.frameworks.pytorch.image_transforms import RandomNoise
from nupic.torch.modules import KWinners
PATH_TO_WHERE_DATASET_WILL_BE_SAVED = PATH = "~/nta/datasets"
In [2]:
class Dataset:
"""Loads a dataset.
Returns object with a pytorch train and test loader
"""
def __init__(self, config=None):
defaults = dict(
dataset_name='MNIST',
data_dir=os.path.expanduser(PATH),
batch_size_train=128,
batch_size_test=128,
stats_mean=None,
stats_std=None,
augment_images=False,
test_noise=False,
noise_level=0.1,
)
defaults.update(config or {})
self.__dict__.update(defaults)
# compute the mean and std used to normalize the dataset, if not provided in the config
if not self.stats_mean or not self.stats_std:
tempset = getattr(datasets, self.dataset_name)(
root=self.data_dir, train=True, transform=transforms.ToTensor()
)
self.stats_mean = (tempset.data.float().mean().item() / 255,)
self.stats_std = (tempset.data.float().std().item() / 255,)
del tempset
# set up transformations
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(self.stats_mean, self.stats_std),
]
)
# set up augment transforms for training (RandomCrop(32) assumes 32x32 images such as CIFAR, not 28x28 MNIST)
if not self.augment_images:
aug_transform = transform
else:
aug_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(self.stats_mean, self.stats_std),
]
)
# load train set
train_set = getattr(datasets, self.dataset_name)(
root=self.data_dir, train=True, transform=aug_transform, download=True
)
self.train_loader = DataLoader(
dataset=train_set, batch_size=self.batch_size_train, shuffle=True
)
# load test set
test_set = getattr(datasets, self.dataset_name)(
root=self.data_dir, train=False, transform=transform, download=True
)
self.test_loader = DataLoader(
dataset=test_set, batch_size=self.batch_size_test, shuffle=False
)
# noise dataset
noise = self.noise_level
noise_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(self.stats_mean, self.stats_std),
RandomNoise(
noise, high_value=0.5 + 2 * 0.2, low_value=0.5 - 2 * 0.2
),
]
)
noise_set = getattr(datasets, self.dataset_name)(
root=self.data_dir, train=False, transform=noise_transform
)
self.noise_loader = DataLoader(
dataset=noise_set, batch_size=self.batch_size_test, shuffle=False
)
In [3]:
dataset = Dataset()
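A quick sanity check on the loaders can help before building the network. The short sketch below is not part of the original notebook: it pulls one batch from the train and noise loaders and prints shapes and statistics; with the defaults above these are 1x28x28 MNIST digits normalized to roughly zero mean and unit variance.
images, labels = next(iter(dataset.train_loader))
print(images.shape, labels.shape)                 # torch.Size([128, 1, 28, 28]) torch.Size([128])
print(images.mean().item(), images.std().item())  # roughly 0 and 1 after Normalize
noisy_images, _ = next(iter(dataset.noise_loader))
print(noisy_images.shape)                         # same shape, with noise applied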
In [4]:
class MLP(nn.Module):
"""Simple 3 hidden layers + output MLP"""
def __init__(self, config=None):
super(MLP, self).__init__()
defaults = dict(
device='cpu',
input_size=784,
num_classes=10,
hidden_sizes=[400, 400, 400],
batch_norm=False,
dropout=False,
use_kwinners=False
)
defaults.update(config or {})
self.__dict__.update(defaults)
self.device = torch.device(self.device)
# decide which activation function to use
if self.use_kwinners: self.activation_func = self._kwinners
else: self.activation_func = lambda _: nn.ReLU()
# create the layers
layers = [
*self._linear_block(self.input_size, self.hidden_sizes[0]),
*self._linear_block(self.hidden_sizes[0], self.hidden_sizes[1]),
*self._linear_block(self.hidden_sizes[1], self.hidden_sizes[2]),
nn.Linear(self.hidden_sizes[2], self.num_classes),
]
self.classifier = nn.Sequential(*layers)
def _linear_block(self, a, b):
block = [nn.Linear(a, b), self.activation_func(b)]
if self.batch_norm: block.append(nn.BatchNorm1d(b))
if self.dropout: block.append(nn.Dropout(p=self.dropout))
return block
def _kwinners(self, num_units):
return KWinners(
n=num_units,
percent_on=0.3,
boost_strength=1.4,
boost_strength_factor=0.7,
)
def forward(self, x):
# need to flatten input before forward pass
return self.classifier(x.view(-1, self.input_size))
def alternative_forward(self, x):
"""Replace forward function by this to visualize activations"""
# need to flatten before forward pass
x = x.view(-1, self.input_size)
for layer in self.classifier:
# apply the transformation
x = layer(x)
# do something with the activation
print(torch.mean(x).item())
return x
In [5]:
network = MLP()
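Before training, it can be instructive to see what a KWinners layer does on its own. The sketch below (not from the original notebook) applies a standalone KWinners module, using the same parameters as MLP._kwinners, to a random activation tensor; roughly percent_on of the units (about 120 of 400) survive per sample and the rest are zeroed.
kw = KWinners(n=400, percent_on=0.3, boost_strength=1.4, boost_strength_factor=0.7)
kw.train()                   # training mode, so duty cycles update
x = torch.rand(8, 400)       # a fake batch of hidden activations
y = kw(x)
print((y != 0).sum(dim=1))   # ~120 active units per sample (30% of 400)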
In [6]:
summary(network, input_size=(1,28,28))
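As a cross-check of the summary (a one-liner, not in the original notebook), the parameter count can also be computed directly from the module; for the default configuration the activations add no parameters, so the total is 638,810.
n_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(n_params)  # 784*400 + 400 + 2*(400*400 + 400) + 400*10 + 10 = 638,810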
In [7]:
class BaseModel:
"""Base model, with training loops and logging functions."""
def __init__(self, network, dataset, config=None):
defaults = dict(
learning_rate=0.1,
momentum=0.9,
device="cpu",
lr_scheduler=False,
sparsity=0,
weight_decay=1e-4,
test_noise=False
)
defaults.update(config or {})
self.__dict__.update(defaults)
self.device = torch.device(self.device)
self.network = network.to(self.device)
self.dataset = dataset
self.setup()
# apply sparsity
if self.sparsity:
self.sparse_layers = []
for layer in self.network.modules():
if isinstance(layer, nn.Linear):
shape = layer.weight.shape
mask = (torch.rand(shape) > self.sparsity).float().to(self.device)
layer.weight.data *= mask
self.sparse_layers.append((layer, mask))
def setup(self):
self.optimizer = optim.SGD(
self.network.parameters(),
lr=self.learning_rate,
momentum=self.momentum,
weight_decay=self.weight_decay,
)
# add a learning rate scheduler (expects num_epochs in the config)
if self.lr_scheduler:
milestones = [int(self.num_epochs / 2), int(self.num_epochs * 3 / 4)]
self.lr_scheduler = schedulers.MultiStepLR(
self.optimizer, milestones=milestones, gamma=0.1
)
# init loss function
self.loss_func = nn.CrossEntropyLoss()
self.epoch = 0
def train(self, num_epochs, test_noise=False):
for i in range(num_epochs):
log = self.run_epoch(test_noise)
# print acc
if test_noise:
print("Train acc: {:.4f}, Val acc: {:.4f}, Noise acc: {:.4f}".format(
log['train_acc'], log['val_acc'], log['test_acc']))
else:
print("Train acc: {:.4f}, Val acc: {:.4f}".format(
log['train_acc'], log['val_acc']))
def run_epoch(self, test_noise):
log = {}
self.epoch += 1
log['current_epoch'] = self.epoch
# train
self.network.train()
log['train_loss'], log['train_acc'] = \
self._run_one_pass(self.dataset.train_loader, train=True)
# validate
self.network.eval()
log['val_loss'], log['val_acc'] = \
self._run_one_pass(self.dataset.test_loader, train=False)
# additional validation for noise
if test_noise:
log['test_loss'], log['test_acc'] = \
self._run_one_pass(self.dataset.noise_loader, train=False, noise=True)
# any updates post epoch, e.g. learning rate scheduler
self._post_epoch_updates(self.dataset)
return log
def _post_epoch_updates(self, dataset=None):
# update learning rate
if self.lr_scheduler:
self.lr_scheduler.step()
def _run_one_pass(self, loader, train=True, noise=False):
epoch_loss = 0
correct = 0
for inputs, targets in loader:
# setup for training
inputs = inputs.to(self.device)
targets = targets.to(self.device)
self.optimizer.zero_grad()
# training loop
with torch.set_grad_enabled(train):
# forward + backward + optimize
outputs = self.network(inputs)
_, preds = torch.max(outputs, 1)
correct += torch.sum(targets == preds).item()
loss = self.loss_func(outputs, targets)
if train:
loss.backward()
self.optimizer.step()
# if sparse, apply the mask to weights after optimization
if self.sparsity:
for layer, mask in self.sparse_layers:
layer.weight.data *= mask
# keep track of loss
epoch_loss += loss.item() * inputs.size(0)
# average loss and accuracy over the full pass
loss = epoch_loss / len(loader.dataset)
acc = correct / len(loader.dataset)
return loss, acc
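The static sparsity mechanism above is worth isolating: a fixed binary mask zeroes a fraction of each linear layer's weights at construction time and is re-applied after every optimizer step, so pruned weights cannot be revived by gradient updates. A minimal standalone sketch of the same idea (toy layer and values, not from the original notebook):
layer = nn.Linear(10, 10)
mask = (torch.rand(layer.weight.shape) > 0.8).float()   # keep ~20% of the weights
layer.weight.data *= mask
opt = optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.rand(4, 10)).sum()
loss.backward()
opt.step()                                               # gradients touch every weight...
layer.weight.data *= mask                                # ...so the mask is re-applied
print((layer.weight == 0).float().mean().item())         # ~0.8, the configured sparsity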
In [8]:
model = BaseModel(network, dataset)
In [9]:
model.train(num_epochs=5, test_noise=True)
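The noise test above uses a single noise_level of 0.1. A small sketch (not in the original notebook, with illustrative noise levels) that rebuilds the Dataset at higher noise levels and reuses the model's internal evaluation pass to trace how accuracy degrades:
model.network.eval()
for level in [0.1, 0.25, 0.5]:
    noisy = Dataset(dict(noise_level=level))
    _, acc = model._run_one_pass(noisy.noise_loader, train=False, noise=True)
    print("noise {:.2f}: acc {:.4f}".format(level, acc))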
In [10]:
# rebuild the network with KWinners as the activation function
network = MLP(dict(use_kwinners=True))
# build model
model = BaseModel(network, dataset)
# run model
model.train(num_epochs=5, test_noise=True)
In [11]:
# rebuild the network with KWinners as the activation function
network = MLP(dict(use_kwinners=True))
# build model
model = BaseModel(network, dataset, dict(sparsity=0.8))
# run model
model.train(num_epochs=5, test_noise=True)
In [12]:
for layer in network.modules():
if isinstance(layer, nn.Linear):
print(torch.mean(layer.weight).item(), torch.std(layer.weight).item())
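The mean and standard deviation above do not show the mask structure directly; the fraction of exactly-zero weights per linear layer is a more direct check and should sit near the configured sparsity of 0.8 (a small sketch, not in the original notebook):
for layer in network.modules():
    if isinstance(layer, nn.Linear):
        zero_frac = (layer.weight == 0).float().mean().item()
        print("zero fraction: {:.3f}".format(zero_frac))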
In [13]:
network.forward = network.alternative_forward
network(torch.rand(1,1,28,28));
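Replacing forward() works, but standard PyTorch forward hooks achieve the same without touching the model and can be removed afterwards. A sketch (the hook name log_mean is illustrative) that registers a hook on every layer of the classifier and runs a flattened random input through it:
def log_mean(module, inputs, output):
    print(module.__class__.__name__, torch.mean(output).item())

handles = [layer.register_forward_hook(log_mean) for layer in network.classifier]
_ = network.classifier(torch.rand(1, 784))
for handle in handles:
    handle.remove()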
In [ ]: