In [1]:
import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir)
from torchnlp.datasets import penn_treebank_dataset
import torch
from torch.utils.data import DataLoader
from rsm_samplers import ptb_pred_sequence_collate
from ptb import lang_util
from torchvision import datasets, transforms
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.nn import CrossEntropyLoss, MSELoss
from importlib import reload
from torch.utils.data import Sampler, BatchSampler
import rsm
from functools import reduce, partial
import matplotlib.pyplot as plt
In [2]:
reload(lang_util)
corpus = lang_util.Corpus('/Users/jgordon/nta/datasets/PTB')
print(corpus.read_out(corpus.train[:60]))
print(len(corpus.valid), len(corpus.test), len(corpus.train))
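# Note: corpus.train / corpus.valid / corpus.test appear to be token sequences and
# read_out() maps them back to readable words (an inference from the usage above,
# not verified against lang_util).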
In [272]:
import rsm_samplers
import rsm
from ptb import lang_util
reload(rsm_samplers)
reload(rsm)
reload(lang_util)
VISUAL = False
DATA_KIND = 'smnist'  # or 'ptb'

if VISUAL:
    MB = 3
    EPOCHS = 1
    BS = 1
else:
    BS = 128
    MB = 100
    EPOCHS = 10

if DATA_KIND == 'smnist':
    d_in = 28 ** 2
    d_out = d_in
elif DATA_KIND == 'ptb':
    VS = 10000
    d_in = d_out = 28

m = 20
n = 4
k = 3
k_winner_cells = 1
eps = [0.0, 0.0]
forget_mu = 0.0

N_LAYERS = 2
weight_sparsity = None
tp_boosting = True
feedback = True
decode_from_full_memory = True
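# Naming gloss (an assumption based on typical RSM conventions, not confirmed by this
# notebook): m = number of columns, n = cells per column, k = active columns per step,
# k_winner_cells = winning cells per active column.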
model = rsm.RSMNet(n_layers=N_LAYERS,
                   m=m,
                   n=n,
                   d_in=d_in,
                   d_out=d_out,
                   embed_dim=d_in,
                   k=k,
                   k_winner_cells=k_winner_cells,
                   eps=eps,
                   do_inhibition=False,
                   x_b_norm=True,
                   boost_strat='col_boosting',
                   forget_mu=forget_mu,
                   decode_from_full_memory=decode_from_full_memory,
                   weight_sparsity=weight_sparsity,
                   feedback=feedback,
                   input_bias=True,
                   loss_layers='first',
                   tp_boosting=tp_boosting,
                   visual_debug=VISUAL)
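# Optional sanity check: model.parameters() is handed to the optimizer below, so RSMNet
# is assumed to behave like a standard torch.nn.Module; this just counts trainable weights.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters: %d" % n_params)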
if DATA_KIND == "smnist":
    dataset = rsm_samplers.MNISTBufferedDataset("~/nta/datasets", download=True,
                                                transform=transforms.Compose([
                                                    transforms.ToTensor(),
                                                    transforms.Normalize((0.1307,), (0.3081,))
                                                ]))
    sampler = rsm_samplers.MNISTSequenceSampler(dataset, sequences=[[0, 1, 2, 3], [0, 3, 2, 1]],
                                                batch_size=BS,
                                                max_batches=MB,
                                                random_mnist_images=True)
    loader = DataLoader(dataset,
                        batch_sampler=sampler,
                        collate_fn=rsm_samplers.pred_sequence_collate)
elif DATA_KIND == "ptb":
    bwe = lang_util.BitwiseWordEmbedding()
    corpus = lang_util.Corpus('/Users/jgordon/nta/datasets/PTB')
    collate_fn = partial(rsm_samplers.ptb_pred_sequence_collate, vector_dict=bwe.embedding_dict)
    sampler = rsm_samplers.PTBSequenceSampler(corpus.train, batch_size=BS, max_batches=MB)
    loader = DataLoader(corpus.train,
                        batch_sampler=sampler,
                        collate_fn=collate_fn)
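# Quick shape check on one batch (left commented out so it doesn't consume sampler state);
# the field order mirrors the unpacking used in the training loop below:
# inputs, target, pred_target, pred_input = next(iter(loader))
# print(inputs.shape, pred_target.shape)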
criterion = MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
In [273]:
import util
reload(util)
hidden = model.init_hidden(BS)
def _repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(_repackage_hidden(v) for v in h)
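# Detaching the hidden state between batches is truncated backpropagation through time:
# gradients flow only within the current batch instead of through the entire sequence
# history, which keeps memory bounded and avoids backpropagating through freed graphs.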
last_output = None
for i in range(EPOCHS):
    total_loss = 0.0
    for batch_id, (inputs, target, pred_target, pred_input) in enumerate(loader):
        # Detach hidden state from the previous batch's graph
        hidden = _repackage_hidden(hidden)
        optimizer.zero_grad()
        out, hidden = model(inputs, hidden)
        # We need the prior output to calculate loss at layers i > 0.
        # First-layer loss is between the last image prediction and the actual input image.
        loss = criterion(out[0], inputs.detach())
        # Loss for layers > 1 is between the last output (hidden predictions) and the actual hidden state
        # ls_loss = criterion(torch.stack(last_output[1:]), x_b[1:].detach())
        loss.backward()
        util._plot_grad_flow(model)
        optimizer.step()
        total_loss += loss.item()
        if batch_id % 50 == 0:
            print("Finished batch %d" % batch_id)
            if last_output is not None:
                model._plot_tensors([
                    ('last_out', last_output),
                    ('inputs', [inputs.detach(), None]),
                    ('x_b', hidden[0])
                ])
        # Keep the prior prediction around for the visual debugging above
        last_output = out
    print('epoch', i, 'loss', total_loss)