In [121]:
%load_ext autoreload
%autoreload 2
In [415]:
import torch
import torch.nn as nn
import torchvision.models as models
from utils import Dataset
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
rcParams['figure.figsize'] = (12,6)
Set up a baseline to compare k-winners against; it should be fast to train and evaluate.
In [103]:
dataset = Dataset(config=dict(dataset_name='MNIST', data_dir='~/nta/datasets',
batch_size_train=256, batch_size_test=1024))
In [107]:
# torch's cross_entropy combines a log-softmax activation with the negative log likelihood loss
loss_func = F.cross_entropy
# a custom Lambda module
class Lambda(nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
# simple feedforward model
# use a lambda layer to resize
model = nn.Sequential(
Lambda(lambda x: x.view(-1,28*28)),
nn.Linear(784,100),
nn.ReLU(),
nn.Linear(100,10),
)
# calculate accuracy
def accuracy(loader, num_batches=3):
len_dataset = loader.dataset.data.size()[0]
running_acc = 0
running_count = 0
# do not cover the entire dataset; the training loader is shuffled, the test loader is not
iter_loader = iter(loader)
for _ in range(num_batches):
x,y = next(iter_loader)
out = model(x)
preds = torch.argmax(out, dim=1)
running_acc += (preds == y).float().sum()
running_count += x.size()[0]
return running_acc.item() / running_count
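As a quick sanity check of the cross_entropy comment above, the following cell (illustrative, not part of the original run) verifies that F.cross_entropy matches log-softmax followed by nll_loss:
In [ ]:
# sanity check: cross_entropy == log_softmax + nll_loss
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(F.cross_entropy(logits, targets), manual))  # True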
In [497]:
# baseline
def fit(model, dataset, verbose=True, epochs=1, epoch_eval=True):
test_accuracies = []
losses = []
# dataset
train_loader = dataset.train_loader
test_loader = dataset.test_loader
# hyperparams
opt = optim.SGD(model.parameters(), lr=.01, momentum=0.9)
num_batches = 60
# training loop
print("Training Accuracy before training: {:.4f}".format(accuracy(train_loader)))
for epoch in range(epochs):
model.train()
iter_loader = iter(train_loader)
for i in range(num_batches):
x,y = next(iter_loader)
# calculate loss
loss = loss_func(model(x), y)
losses.append(loss.item())
# backpropagate
loss.backward()
# learn
opt.step()
opt.zero_grad()
if verbose:
if i % 20 == 0:
print("Loss: {:.8f}".format(loss.item()*1000/len(x)))
if epoch_eval:
model.eval()
test_acc = accuracy(test_loader)
test_accuracies.append(test_acc)
model.eval()
print("Training Accuracy after training: {:.4f}".format(accuracy(train_loader)))
print("Test Accuracy after training: {:.4f}".format(accuracy(test_loader)))
print("---------------------------")
return losses, test_accuracies
In [109]:
fit(model, dataset)
In [408]:
# from functions import KWinnersBatch as KWinners
from functions import KWinners
model_gen = lambda k: nn.Sequential(
Lambda(lambda x: x.view(-1,28*28)),
nn.Linear(784,100),
KWinners(k_perc=k),
nn.Linear(100,10),
)
model = model_gen(.1)
fit(model, dataset, epochs=1)
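For reference, a minimal sketch of what a k-winner-take-all activation could look like; the actual KWinners in functions.py (with boosting, absolute-value selection, etc.) may differ, so treat this as an assumption-laden illustration rather than its implementation.
In [ ]:
# illustrative k-winners sketch (assumed behavior; not the real KWinners from functions.py)
class SimpleKWinners(nn.Module):
    def __init__(self, k_perc=0.1):
        super().__init__()
        self.k_perc = k_perc
    def forward(self, x):
        # keep only the top k% activations per sample, zero the rest
        flat = x.view(x.size(0), -1)
        k = max(1, int(self.k_perc * flat.size(1)))
        _, idx = flat.topk(k, dim=1)
        mask = torch.zeros_like(flat).scatter_(1, idx, 1.0)
        return (flat * mask).view_as(x)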
In [356]:
for k in np.arange(0.01,1,0.1):
print("K: %f" % k)
model = model_gen(k)
fit(model, dataset, verbose=False)
Model accuracy with 1% of active neurons, with and without boosting:
In [249]:
# no non-linearity is needed to reach a comparably low accuracy (purely linear reference model)
model = nn.Sequential(
Lambda(lambda x: x.view(-1,28*28)),
nn.Linear(784,100),
nn.Linear(100,10),
)
fit(model, dataset)
In [ ]:
Further tests and comparisons
In [416]:
# simple CNN Model
non_linearity = nn.ReLU
model = nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
non_linearity(),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
non_linearity(),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
non_linearity(),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
losses, accs = fit(model, dataset, epochs=20)
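The spatial sizes noted in the comments (28 -> 14 -> 7 -> 4 with kernel 3, stride 2, padding 1, so the flattened size is 128*4*4 = 2048) can be confirmed by pushing a dummy batch through the stack; an illustrative check:
In [ ]:
# illustrative shape check with a dummy MNIST-sized input
x = torch.randn(2, 1, 28, 28)
for layer in model:
    x = layer(x)
    print(type(layer).__name__, tuple(x.shape))
# Conv2d: (2,32,14,14) -> (2,64,7,7) -> (2,128,4,4); Lambda flatten: (2,2048); Linear: (2,10)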
In [418]:
# kWinners
from functions import KWinners
model_gen = lambda k: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
model = model_gen(k=0.25)
kw_losses, kw_accs = fit(model, dataset, epochs=20)
In [437]:
# kWinners
from functions import KWinners
model_gen = lambda k: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k, use_absolute=False),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k, use_absolute=False),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k, use_absolute=False),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
model = model_gen(k=0.25)
kwp_losses, kwp_accs = fit(model, dataset, epochs=20)
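The use_absolute flag presumably decides whether winners are ranked by absolute activation or by signed activation; a small illustration of that assumed difference (the actual flag in functions.py may behave differently):
In [ ]:
# assumed semantics of use_absolute: rank by |activation| vs. signed activation
a = torch.tensor([[-3.0, 0.5, 2.0, -0.1]])
_, idx_abs = a.abs().topk(2, dim=1)  # indices 0 and 2 (|-3.0| and |2.0| are largest)
_, idx_pos = a.topk(2, dim=1)        # indices 2 and 1 (2.0 and 0.5 are largest)
print(idx_abs, idx_pos)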
In [438]:
plt.plot(kw_losses, label='kw_abs_losses')
plt.plot(kwp_losses, label='kw_pos_losses')
plt.plot(losses, label='losses')
plt.legend();
In [453]:
# kWinners without boosting
from functions import KWinners
model_gen = lambda k: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k, use_boosting=False),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k, use_boosting=False),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k, use_boosting=False),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
model = model_gen(k=0.1)
fit(model, dataset, epochs=1, epoch_eval=False);
In [454]:
# kWinners with boosting
from functions import KWinners
model_gen = lambda k: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k, use_boosting=True),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k, use_boosting=True),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k, use_boosting=True),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
model = model_gen(k=0.1)
fit(model, dataset, epochs=1, epoch_eval=False);
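Boosting in the HTM-style formulation tracks each unit's duty cycle and scales its pre-activations by exp(beta * (target_density - duty_cycle)) before picking winners, so rarely-active units get a temporary advantage. A hedged sketch of that update follows; the KWinners implementation in functions.py may differ in its details, and names like pick_boosted_winners and alpha here are purely illustrative.
In [ ]:
# assumed duty-cycle boosting, HTM-style; illustrative only
target_density = 0.1                  # i.e. k_perc
beta = 0.01                           # boost strength
alpha = 0.01                          # duty-cycle update rate
duty_cycle = torch.zeros(100)

def pick_boosted_winners(x):
    """x: (batch, 100) pre-activations; returns a 0/1 winner mask."""
    global duty_cycle
    boost = torch.exp(beta * (target_density - duty_cycle))  # >1 for under-active units
    k = int(target_density * x.size(1))
    _, idx = (x * boost).topk(k, dim=1)                      # rank by boosted activation
    mask = torch.zeros_like(x).scatter_(1, idx, 1.0)
    duty_cycle = (1 - alpha) * duty_cycle + alpha * mask.mean(dim=0)
    return mask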
In [465]:
# Exploring several values for beta
from functions import KWinners
model_gen = lambda k,b: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k, use_boosting=True, beta=b),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k, use_boosting=True, beta=b),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k, use_boosting=True, beta=b),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
for b in [0, 0.01, 0.02, 0.05, 0.1]:
model = model_gen(k=0.1, b=b)
print("Beta: %f" % b)
fit(model, dataset, epochs=1, epoch_eval=False);
In [494]:
# Exploring several values for beta
from functions import KWinners
model_gen = lambda k,b: nn.Sequential(
nn.Conv2d(1,32, kernel_size=3, stride=2, padding=1), # 14x14
KWinners(k_perc=k, use_boosting=True, beta=b),
nn.Conv2d(32,64, kernel_size=3, stride=2, padding=1), # 7x7
KWinners(k_perc=k, use_boosting=True, beta=b),
nn.Conv2d(64,128, kernel_size=3, stride=2, padding=1), # 4x4
KWinners(k_perc=k, use_boosting=True, beta=b),
Lambda(lambda x: x.view(x.size(0), -1)), # 128*4*4 = 2048
nn.Linear(128*4*4,10) # 10
)
betas, losses, accs = [], [], []
for b in [0, 0.002, 0.01]:
model = model_gen(k=0.1, b=b)
print("Beta: %f" % b)
loss, acc = fit(model, dataset, epochs=50, epoch_eval=True, verbose=False);
betas.append(b)
losses.append(loss)
accs.append(acc)
Plot loss and accuracy for different betas to answer whether boosting helps or hurts.
In [498]:
rcParams['figure.figsize'] = (16,8)
for beta, loss in zip(betas, losses):
plt.plot(loss, label=str(beta))
plt.legend();
In [496]:
for beta, acc in zip(betas, accs):
plt.plot(acc, label=str(beta))
plt.legend();
In [ ]:
# topk experimentation
t = torch.randn((4,3,2))
In [467]:
t
Out[467]:
In [468]:
b = torch.ones((3,2)) * 2
In [469]:
b
Out[469]:
In [471]:
t.shape, b.shape
Out[471]:
In [473]:
b.expand((4,3,2))
Out[473]:
In [474]:
t * b.expand((4,3,2))
Out[474]:
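The explicit expand is not strictly necessary; broadcasting produces the same result (illustrative check):
In [ ]:
# broadcasting handles the (3,2) -> (4,3,2) expansion implicitly
print(torch.equal(t * b, t * b.expand((4,3,2))))  # True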
In [296]:
# topk experimentation
t = torch.randn((4,3,2))
In [310]:
t
Out[310]:
In [320]:
tx = t.view(t.size()[0], -1)
print(tx.size())
val, _ = torch.kthvalue(tx, 1, dim=-1)
val
Out[320]:
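kthvalue with k=1 returns the smallest value per row; to threshold at the k-th largest activation (as the comparisons below do with val), one would ask for the (n - k + 1)-th value, e.g. (illustrative):
In [ ]:
# per-sample threshold at the k-th largest value (illustrative)
k = 2
thresh, _ = torch.kthvalue(tx, tx.size(1) - k + 1, dim=1)
mask = t >= thresh.view(-1, 1, 1)            # keep the top-k entries per sample
print(mask.view(t.size(0), -1).sum(dim=1))   # 2 per sample, barring ties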
In [337]:
[t.size()[0]] + [1 for _ in range(len(t.size())-1)]
Out[337]:
In [341]:
t.shape
Out[341]:
In [342]:
(t > val.view(4,1,1)).sum(dim=0).shape
Out[342]:
In [343]:
(t > val.view(4,1,1)).sum(dim=0)
Out[343]:
In [340]:
t > val.view(4,1,1)
Out[340]:
In [328]:
t.shape
Out[328]:
In [330]:
val.view(4,1,1).shape
Out[330]:
In [331]:
val.view(4,1,1)
Out[331]:
In [313]:
val, ind = torch.topk(t, k=1, dim=2)
In [314]:
ind
Out[314]:
In [287]:
ind.shape
Out[287]:
In [288]:
t.shape
Out[288]:
In [289]:
t[ind].shape
Out[289]:
In [285]:
t.shape
Out[285]:
In [290]:
# build a mask from the top-k indices
mask = torch.zeros_like(t)
# note: scatter is out-of-place, so mask itself stays zero here;
# scatter_(2, ind, 1.) along the same dim used for topk would modify mask in place
mask.scatter(1, ind, 1.)
Out[290]:
In [286]:
ind
Out[286]:
In [509]:
t = torch.randn(4,4,3)
In [510]:
t
Out[510]:
In [512]:
t.topk(2, dim=2)
Out[512]:
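Tying the scratch cells together: the topk indices can be turned into a winner mask with the in-place scatter_ along the same dim that topk used (illustrative):
In [ ]:
# turn top-k indices into a 0/1 mask along dim 2 (illustrative)
vals, idx = t.topk(2, dim=2)
mask = torch.zeros_like(t).scatter_(2, idx, 1.0)
print((t * mask)[0])  # only the top-2 entries along dim 2 survive per row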