In [1]:
import os
import sys
import numpy as np
import torch
from torchvision import datasets, transforms
from htmresearch.frameworks.pytorch.mnist_sparse_experiment import \
MNISTSparseExperiment
from htmresearch.frameworks.pytorch.model_utils import evaluateModel
from htmresearch.frameworks.pytorch.benchmark_utils import (
register_nonzero_counter, unregister_counter_nonzero)
from tabulate import tabulate
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
In [2]:
%matplotlib inline
def analyzeDutycyclePruning(path, params, minW, minD):
    """
    Load a trained model, prune it, evaluate it on MNIST and report sparsity.

    The checkpoint ``<path>/model.pt`` is loaded on the CPU, pruned through
    ``model.pruneWeights(minW)`` and ``model.pruneDutycycles(minD)`` and then
    evaluated on the MNIST test split. Non-zero counters are registered on the
    model for the duration of the evaluation and the resulting "weight" counts
    are printed as a table. Finally, for every model buffer whose name contains
    ``dutyCycle`` the min / max / non-zero count is printed and a 100-bin
    histogram is plotted.

    :param path: directory containing the ``model.pt`` checkpoint
    :param params: experiment parameters; the keys ``"datadir"`` and
                   ``"test_batch_size"`` are read here
    :param minW: threshold forwarded to ``model.pruneWeights`` (the exact
                 pruning rule is defined by the model class, not here)
    :param minD: threshold forwarded to ``model.pruneDutycycles``; duty cycle
                 values ``<= minD`` are additionally set to zero before they
                 are reported
    :return: tuple ``(model, accuracy)`` where ``accuracy`` is
             ``results["accuracy"]`` as returned by ``evaluateModel``
    """
    # Dataset transformations used during training. See mnist_sparse_experiment.py
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Initialize MNIST test dataset for this experiment
    test_dataset = datasets.MNIST(params["datadir"], train=False, download=True,
                                  transform=transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=params["test_batch_size"], shuffle=True)

    # Load pre-trained model and prune it before evaluating.
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    model = torch.load(os.path.join(path, "model.pt"), map_location="cpu")
    model.pruneWeights(minW)
    model.pruneDutycycles(minD)

    # Collect non-zero counts while the model is evaluated on the test set
    nonzero = {}
    register_nonzero_counter(model, nonzero)
    results = evaluateModel(model, test_loader)
    unregister_counter_nonzero(model)

    # Keep only the "weight" counters, without empty rows and without the
    # "input"/"output" entries.
    counters = pd.DataFrame.from_dict(nonzero)
    table = pd.DataFrame({"nonzero weight": counters.xs("weight")})
    table = table.dropna().drop(["input", "output"])
    print(tabulate(table, headers='keys', tablefmt='fancy_grid', numalign="right"))

    for k, v in model.named_buffers():
        if 'dutyCycle' not in k:
            continue

        # NOTE: ``numpy()`` returns an array sharing memory with the buffer, so
        # the thresholding below also zeroes these entries in the model itself;
        # code that reads the buffer from the returned model sees the
        # thresholded values.
        dutyCycle = v.view(-1).numpy()
        dutyCycle[dutyCycle <= minD] = 0

        # Single-argument print() yields the same text on Python 2 and 3.
        print("{} , min: {} , max: {} , nonzero: {}".format(
            k, dutyCycle.min(), dutyCycle.max(), np.count_nonzero(dutyCycle)))

        plt.figure()
        pd.Series(dutyCycle).plot(kind='hist', title=k, bins=100)

    return model, results["accuracy"]
In [3]:
# Configure the experiment suite as if it had been launched from the command
# line with the paper configuration file.
# NOTE(review): sys.argv[0] is conventionally the program name. If parse_opt()
# parses sys.argv[1:] (the optparse/argparse default), '-c' is consumed as the
# program name and 'experiments_paper.cfg' becomes a positional argument
# rather than the value of -c. Confirm which config file is actually loaded.
sys.argv = ['-c', 'experiments_paper.cfg']
suite = MNISTSparseExperiment()
suite.parse_opt()
suite.parse_cfg()

# Use the first entry returned for the "sparseCNN2" experiment
path = suite.get_exp("sparseCNN2")[0]
results = suite.get_exps(path=path)

# Baseline run: both thresholds are 0.0
print("Before prunning")
model, accuracy = analyzeDutycyclePruning(path=path,
                                          params=suite.get_params(results[0]),
                                          minW=0.0,
                                          minD=0.0)
In [4]:
# Summarize the baseline run: test accuracy plus the number of non-zero
# entries in the first linear SDR layer's weights and duty cycle buffer.
dc = model.linearSdr.linearSdr1.dutyCycle
w = model.linearSdr.linearSdr1.l1.weight

# Single-argument print() yields the same text on Python 2 and 3.
print("accuracy: {}".format(accuracy))
print("w.nonzero: {}".format(torch.nonzero(w).size(0)))
print("dc.nonzero: {}".format(torch.nonzero(dc).size(0)))
In [5]:
# Pruned run: same experiment, now with a weight threshold of 0.03 and a
# duty cycle threshold of 0.01.
print("After weight prunning")
model, accuracy = analyzeDutycyclePruning(path=path,
                                          params=suite.get_params(results[0]),
                                          minW=0.03,
                                          minD=0.01)
In [6]:
# Summarize the pruned run: test accuracy plus the number of non-zero
# entries in the first linear SDR layer's weights and duty cycle buffer.
dc = model.linearSdr.linearSdr1.dutyCycle
w = model.linearSdr.linearSdr1.l1.weight

# Single-argument print() yields the same text on Python 2 and 3.
print("accuracy: {}".format(accuracy))
print("w.nonzero: {}".format(torch.nonzero(w).size(0)))
print("dc.nonzero: {}".format(torch.nonzero(dc).size(0)))
In [ ]: