Goal: Investigate how DSNN fares in a toy problem.

Compare following models:

  • Large dense (same architecture as large sparse, but dense)
  • Small dense (same number of params as large sparse, but dense)
  • Large sparse
  • Large sparse + dynamic sparse

In [1]:
# Auto-reload local modules (networks_module, models_module, utils) so edits
# to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# general imports
import os
import numpy as np

# torch imports
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as schedulers
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary

# nupic research imports
from nupic.research.frameworks.pytorch.image_transforms import RandomNoise
from nupic.torch.modules import KWinners

# local library
from networks_module.base_networks import *
from models_module.base_models import *

# local files
from utils import *
import math

# plotting
import matplotlib.pyplot as plt
from matplotlib import rcParams
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
rcParams['figure.figsize'] = (12,6)

PATH_TO_WHERE_DATASET_WILL_BE_SAVED = PATH = "~/nta/datasets"

Test with k-winners activations (`KWinners`) enabled in all models.


In [3]:
from models_module.base_models import BaseModel, SparseModel, DSNNMixedHeb
from networks_module.hebbian_networks import MLP, MLPHeb

In [4]:
# Load the MNIST dataset (with noisy test variants for robustness evaluation).
# Reuse the module-level PATH constant so the dataset location is defined in
# exactly one place instead of being re-hardcoded here.
dataset_config = dict(
    dataset_name="MNIST",
    data_dir=PATH,
    test_noise=True,
)
dataset = Dataset(dataset_config)

In [13]:
# Shared experiment settings.
test_noise = True
use_kwinners = True
epochs = 15
on_perc = 0.1  # on-percentage for the sparse models trained in later cells

# Large dense: same architecture as the large sparse model, but fully dense.
# Use distinct names for the two configs so the network config is not
# silently clobbered by the model config (the original cell reused `config`
# for both, which also overwrote the dataset config from the previous cell).
network_config = dict(hidden_sizes=[100, 100, 100], use_kwinners=use_kwinners)
network = MLP(config=network_config)
model_config = dict(debug_weights=True)  # log per-layer weight mean/std each epoch
model = BaseModel(network=network, config=model_config)
model.setup()
print("\nLarge Dense")
# NOTE(review): training diverges to NaN after ~6 epochs (see output below);
# presumably the learning rate is too high or gradients need clipping inside
# BaseModel — confirm against the model/optimizer defaults.
large_dense = model.train(dataset, epochs, test_noise=test_noise)


Large Dense
Train acc: 0.8761, Val acc: 0.9380, Noise acc: 0.9314
Train acc: 0.9564, Val acc: 0.9553, Noise acc: 0.9511
Train acc: 0.9649, Val acc: 0.9460, Noise acc: 0.9415
Train acc: 0.9663, Val acc: 0.9549, Noise acc: 0.9506
Train acc: 0.9675, Val acc: 0.9482, Noise acc: 0.9454
Train acc: 0.9635, Val acc: 0.9266, Noise acc: 0.9207
Train acc: 0.5230, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980
Train acc: 0.0987, Val acc: 0.0980, Noise acc: 0.0980

Debugging the dense model: training diverges to NaN after epoch 6 — inspect the logged per-layer weight means and standard deviations below to see where the weights blow up.


In [15]:
# Display the full training history dict (note the NaNs from epoch 7 onward,
# matching the accuracy collapse seen in the training printout above).
large_dense


Out[15]:
defaultdict(list,
            {'train_loss': [0.4117379237174988,
              0.1579397918820381,
              0.12979069519837697,
              0.1254622137248516,
              0.13399272019465763,
              0.1677697219669819,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'train_acc': [0.8760666666666667,
              0.9564166666666667,
              0.9648833333333333,
              0.96625,
              0.96745,
              0.96345,
              0.523,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666,
              0.09871666666666666],
             'linear_0_mean': [-0.00019430241081863642,
              0.00027377685182727873,
              0.00041149111348204315,
              0.001242143684066832,
              0.0016640721587464213,
              0.0038216665852814913,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_0_std': [0.05636117607355118,
              0.0688825398683548,
              0.07854200899600983,
              0.08841497451066971,
              0.09822449088096619,
              0.1112540140748024,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_1_mean': [-0.019750574603676796,
              -0.023961268365383148,
              -0.02698422595858574,
              -0.02992100454866886,
              -0.03549543768167496,
              -0.03972663730382919,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_1_std': [0.09968265891075134,
              0.11661508679389954,
              0.12888020277023315,
              0.14254841208457947,
              0.156602680683136,
              0.1741691380739212,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_2_mean': [-0.017785781994462013,
              -0.019739622250199318,
              -0.022778544574975967,
              -0.02547401376068592,
              -0.02828535996377468,
              -0.03400816768407822,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_2_std': [0.08039693534374237,
              0.08319339156150818,
              0.08607394248247147,
              0.08896781504154205,
              0.09325534850358963,
              0.1021324023604393,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_3_mean': [-0.0027669325936585665,
              -0.00264023058116436,
              -0.0025194010231643915,
              -0.002404184779152274,
              -0.002294291974976659,
              -0.0021893957164138556,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'linear_3_std': [0.18224182724952698,
              0.16757167875766754,
              0.15857206284999847,
              0.14726030826568604,
              0.1397176831960678,
              0.1382737010717392,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'val_loss': [0.22370526374578475,
              0.17709137363284827,
              0.24873587604761124,
              0.2219540862193331,
              0.3319696503367275,
              0.5825064988330007,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'val_acc': [0.938,
              0.9553,
              0.946,
              0.9549,
              0.9482,
              0.9266,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098],
             'noise_loss': [0.24669090679883957,
              0.1845904532313347,
              0.2596750955939293,
              0.2241142802603543,
              0.3100352413157001,
              0.581966226541996,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan,
              nan],
             'noise_acc': [0.9314,
              0.9511,
              0.9415,
              0.9506,
              0.9454,
              0.9207,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098,
              0.098]})

In [16]:
# Plot every logged metric from the training history in a 4-column grid.
results = large_dense
n_plots = len(results)
n_cols = 4
n_rows = math.ceil(n_plots / n_cols)

# squeeze=False guarantees a 2-D axes array even when n_rows == 1, so the
# flat iteration below is always valid (the original axs[i, j] indexing
# raised for a single-row grid).
fig, axs = plt.subplots(
    n_rows, n_cols, squeeze=False,
    gridspec_kw={'hspace': 0.5, 'wspace': 0.5},
)
fig.set_size_inches(16, 16)

for ax, key in zip(axs.flat, sorted(results.keys())):
    ax.plot(range(len(results[key])), results[key])
    ax.set_title(key)

# Hide the unused trailing subplots instead of leaving empty frames.
for ax in list(axs.flat)[n_plots:]:
    ax.set_visible(False)



In [ ]: