In [12]:
# All imports for the notebook live in this cell.
import os
import ray
from ray import tune
# NOTE(review): `torch` is imported again further down in this cell; one of the two is redundant.
import torch # to remove later
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Plain `import models` / `import networks`: presumably project-local modules — confirm.
import models
import networks
from datasets import PreprocessedSpeechDataLoader, VaryingDataLoader
from nupic.research.frameworks.pytorch.image_transforms import RandomNoise
from torchsummary import summary
import math
import torch
from torch import nn
# NOTE(review): this rebinds the name `models` to torchvision's `models`, shadowing
# the `import models` above. Any later use of `models.<X>` resolves to torchvision.
from torchvision import models
from nupic.torch.modules import Flatten, KWinners, KWinners2d
from networks_module.layers import DSConv2d, RandDSConv2d, SparseConv2d
In [2]:
# Experiment configuration, passed as `config=` to the network factory below.
config = {
    # Use the GPU when at least one CUDA device is visible, otherwise the CPU.
    "device": "cuda" if torch.cuda.device_count() > 0 else "cpu",
    "dataset_name": "PreprocessedGSC",
    "data_dir": "~/nta/datasets/gsc",
    "batch_size_train": (4, 16),
    "batch_size_test": 1000,

    # ----- Network -----
    # To sweep over model types instead of fixing one, use:
    #   tune.grid_search(["BaseModel", "SparseModel", "DSNNMixedHeb", "DSNNConvHeb"])
    "model": "DSNNConvHeb",
    "network": "gsc_conv_heb",

    # ----- Optimizer -----
    "optim_alg": "SGD",
    "momentum": 0,
    "learning_rate": 0.01,
    "weight_decay": 0.01,
    "lr_scheduler": "StepLR",
    "lr_gamma": 0.90,
    # To compare both settings, use: tune.grid_search([True, False])
    "use_kwinners": True,

    # ----- Dynamic sparsity: fully connected layers -----
    # 184.61538 / 3; per the original note this corresponds to
    # "0.1 in the 1600-1000 linear layer".
    "epsilon": 184.61538 / 3,
    "sparse_linear_only": True,
    "start_sparse": 1,
    "end_sparse": -1,  # don't get last layer
    "weight_prune_perc": 0.15,
    "hebbian_prune_perc": 0.60,
    "pruning_es": True,
    "pruning_es_patience": 0,
    "pruning_es_window_size": 5,
    "pruning_es_threshold": 0.02,
    "pruning_interval": 1,

    # ----- Dynamic sparsity: convolutional layers -----
    "prune_methods": "dynamic",
    "hebbian_prune_frac": 0.99,
    "magnitude_prune_frac": 0.0,
    "sparsity": 0.98,
    "update_nsteps": 50,
    "prune_dims": (),

    # ----- Additional validation -----
    "test_noise": False,
    "noise_level": 0.1,

    # ----- Debugging -----
    "debug_weights": True,
    "debug_sparse": True,
}
In [3]:
# Build the network with the `gsc_conv_heb` factory from the `networks` module,
# handing it the full `config` dict.
network = networks.gsc_conv_heb(config=config)
In [4]:
# Print torchsummary's report for the network, using an input size of (1, 32, 32).
summary(network, input_size=(1, 32, 32))
In [5]:
# Smoke test: run one forward pass on a random tensor of shape (10, 1, 32, 32).
# The trailing `;` keeps the returned value from being echoed as cell output.
network(torch.rand(10,1,32,32));
In [6]:
# Inspect the network's `features` attribute
# (presumably the convolutional part of the model — confirm against the factory).
network.features
Out[6]:
In [7]:
# Inspect the network's `classifier` attribute
# (presumably the fully connected head — confirm against the factory).
network.classifier
Out[7]:
In [17]:
# List every submodule as (qualified name, class name).
# `named_modules` is a method: referencing it without calling it only shows the
# bound-method repr, not the modules themselves. The root module has the name "".
[(name, type(module).__name__) for name, module in network.named_modules()]
Out[17]:
In [18]:
# Walk every submodule and report only the dynamic-sparse conv layers:
# print the layer's qualified name followed by its class.
for layer_name, layer in network.named_modules():
    if not isinstance(layer, DSConv2d):
        continue  # not a DSConv2d layer; nothing to report
    print(layer_name)
    print(layer.__class__)
In [ ]: