In [1]:
import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 

from torchnlp.datasets import penn_treebank_dataset
import torch
from torchnlp.samplers import BPTTBatchSampler
from torch.utils.data import DataLoader
from rsm_samplers import MNISTSequenceSampler, ptb_pred_sequence_collate
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from importlib import reload 
from torch.utils.data import Sampler, BatchSampler
import rsm
from matplotlib.lines import Line2D
import numpy as np
import torchvision.utils as vutils
from functools import reduce, partial
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [2]:
import time
import rsm_samplers
import rsm
import util
import baseline_models
reload(rsm)
reload(rsm_samplers)
reload(util)
reload(baseline_models)

from torch.utils.data import DataLoader, BatchSampler

writer = SummaryWriter()

BSZ = 10  # batch size
MB = 10   # max batches per epoch

# Eight fixed 9-digit sequences used by the MNIST sequence sampler.
PAGI9 = [[2, 4, 0, 7, 8, 1, 6, 1, 8],
         [2, 7, 4, 9, 5, 9, 3, 1, 0],
         [5, 7, 3, 4, 1, 3, 1, 6, 4],
         [1, 3, 7, 5, 2, 5, 5, 3, 4],
         [2, 9, 1, 9, 2, 8, 3, 2, 7],
         [1, 2, 6, 4, 8, 3, 5, 0, 3],
         [3, 8, 0, 5, 6, 4, 1, 3, 9],
         [4, 7, 5, 3, 7, 6, 7, 2, 4]]

dataset = rsm_samplers.MNISTBufferedDataset("~/nta/datasets", download=True,
                                            transform=transforms.Compose([
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.1307,), (0.3081,))
                                            ]),)

sampler = rsm_samplers.MNISTSequenceSampler(dataset, sequences=PAGI9, 
                                            batch_size=BSZ,
                                            max_batches=MB,
                                            random_mnist_images=True)

loader = DataLoader(dataset,
                    batch_sampler=sampler,
                    collate_fn=rsm_samplers.pred_sequence_collate)
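
# Quick sanity check (a sketch): pull one batch to see the collate output order
# the training loop below relies on -- input images, next-step image targets,
# next-element class targets, and the current digit labels.
_inp, _tgt, _pred_tgts, _labels = next(iter(loader))
print(_inp.shape, _tgt.shape, _pred_tgts.shape, _labels.shape)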

model = baseline_models.LSTMModel(
                vocab_size=10000,
                embed_dim=100,
                nhid=200,
                d_in=28**2,
                d_out=28**2
            )
predictor = rsm.RSMPredictor(
                d_in=28**2,
                d_out=10000,
                hidden_size=1200
            )

class BPTTTrainer:
    def __init__(self, model, loader, k1=1, k2=30, predictor=None, bsz=BSZ, seed=42):
        self.k1 = k1
        self.k2 = k2
        self.model = model
        self.loader = loader
        self.predictor = predictor
        self.predictor_loss = CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=1e-5)
        self.pred_optimizer = torch.optim.Adam(params=self.predictor.parameters(), lr=1e-5)
        self.bsz = bsz
        self.loss_module = torch.nn.MSELoss()
        self.retain_graph = self.k1 < self.k2
        self.epoch = 0
        
        # Predictor counts
        self.total_samples = 0
        self.correct_samples = 0
        
        self.total_loss = 0.0
        
        if torch.cuda.is_available():
            print("setup: Using cuda")
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(seed)
        else:
            print("setup: Using cpu")
            self.device = torch.device("cpu")
            
        self.model.to(self.device)
        self.predictor.to(self.device)

    def one_step_module(self, inp, hidden):
        out, new_hidden = self.model(inp, hidden)
        return (out, new_hidden)
    
    def _repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self._repackage_hidden(v) for v in h)
        
    def predict(self, out, pred_tgts):
        # Train the digit predictor on the detached model output, so no gradient
        # flows back into the recurrent model from this loss.
        pred_out = self.predictor(out.detach())
        loss = self.predictor_loss(pred_out, pred_tgts)
        self.pred_optimizer.zero_grad()
        loss.backward()
        self.pred_optimizer.step()
        _, class_predictions = torch.max(pred_out, 1)
        self.total_samples += pred_tgts.size(0)
        correct_arr = class_predictions == pred_tgts
        self.correct_samples += correct_arr.sum().item()
        
    def train(self):
        self.total_samples = 0
        self.correct_samples = 0
        self.total_loss = 0.0

        # Each entry pairs a step's (detached) input state with its output state;
        # the initial entry has no input state.
        states = [(None, self.model.init_hidden(self.bsz))]

        outputs = []
        targets = []

        for i, (inp, target, pred_tgts, input_labels) in enumerate(self.loader):
            inp = inp.to(self.device)
            target = target.to(self.device)
            pred_tgts = pred_tgts.to(self.device)

            batch_loss = 0.0
            state = self._repackage_hidden(states[-1][1])
            for h in state:
                # Re-enable grad on the detached hidden state so its .grad can be
                # read and handed back to the previous truncation window.
                h.requires_grad = True
            output, new_state = self.one_step_module(inp, state)

            outputs.append(output)
            targets.append(target)
            while len(outputs) > self.k1:
                # Delete stuff that is too old
                del outputs[0]
                del targets[0]

            states.append((state, new_state))
            while len(states) > self.k2:
                # Delete stuff that is too old
                del states[0]
                
            if (i+1)%self.k1 == 0:
                self.optimizer.zero_grad()
                # backprop last module (keep graph only if they ever overlap)
                start = time.time()
                for j in range(self.k2-1):
                    # print('j', j)
                    if j < self.k1:
                        loss = self.loss_module(outputs[-j-1], targets[-j-1])
                        batch_loss += loss.item()
                        loss.backward(retain_graph=True)

                    # if we get all the way back to the "init_state", stop
                    if states[-j-2][0] is None:
                        break
                    # Hand the gradient of this window's detached input state back
                    # to the previous window's output state.
                    curr_h_grad = states[-j-1][0][0].grad
                    curr_c_grad = states[-j-1][0][1].grad
                    states[-j-2][1][0].backward(curr_h_grad, retain_graph=self.retain_graph)
                    states[-j-2][1][1].backward(curr_c_grad, retain_graph=self.retain_graph)
                # print("opt step, batch loss: %.3f" % batch_loss)
                self.optimizer.step()
                self.total_loss += batch_loss
            
            self.predict(output, pred_tgts)
        
        if self.total_samples:
            train_acc = 100.*self.correct_samples/self.total_samples
            print(self.epoch, "train acc: %.3f%%, loss: %.3f" % (train_acc, self.total_loss))
            writer.add_scalar('train_acc', train_acc, self.epoch)
            writer.add_scalar('train_loss', self.total_loss, self.epoch)            
        
        self.epoch += 1
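
The k1/k2 bookkeeping above is the standard truncated-BPTT pattern: detach the hidden state at every step, re-enable gradients on the detached copy, and manually hand its gradient back to the previous truncation window. A minimal standalone sketch of the same pattern on a toy nn.LSTMCell (the sizes, module names, and data below are illustrative, not part of the project code):

import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=3)
readout = nn.Linear(3, 1)
opt = torch.optim.Adam(list(cell.parameters()) + list(readout.parameters()), lr=1e-3)
loss_fn = nn.MSELoss()

k1, k2 = 1, 5
# Each entry mirrors BPTTTrainer.states: (detached input state, output state).
states = [(None, (torch.zeros(2, 3), torch.zeros(2, 3)))]
for step in range(20):
    x, y = torch.randn(2, 4), torch.randn(2, 1)

    # Detach the previous output state, then re-enable grad so its .grad can be
    # read and handed back to the previous truncation window.
    h, c = (s.detach().requires_grad_(True) for s in states[-1][1])
    new_h, new_c = cell(x, (h, c))
    states.append(((h, c), (new_h, new_c)))
    while len(states) > k2:
        del states[0]

    if (step + 1) % k1 == 0:
        opt.zero_grad()
        loss = loss_fn(readout(new_h), y)
        loss.backward(retain_graph=True)
        # Walk back through the stored windows, feeding each window's
        # input-state gradient into the previous window's output state.
        for j in range(k2 - 1):
            if states[-j - 2][0] is None:
                break
            h_grad = states[-j - 1][0][0].grad
            c_grad = states[-j - 1][0][1].grad
            states[-j - 2][1][0].backward(h_grad, retain_graph=True)
            states[-j - 2][1][1].backward(c_grad, retain_graph=True)
        opt.step()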

In [57]:
trainer = BPTTTrainer(model, loader, predictor=predictor, k1=1, k2=30)
trainer.train()


Pred acc: 0.400%, loss: 50.800
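
Training for more than one pass is just a matter of calling the trainer repeatedly; a sketch (the epoch count is arbitrary):

for _ in range(5):
    trainer.train()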

In [7]:
# PAGI9 = [[2, 4, 0, 7, 8, 1, 6, 1, 8], [2, 7, 4, 9, 5, 9, 3, 1, 0], [5, 7, 3, 4, 1, 3, 1, 6, 4], [1, 3, 7, 5, 2, 5, 5, 3, 4], [
#    2, 9, 1, 9, 2, 8, 3, 2, 7], [1, 2, 6, 4, 8, 3, 5, 0, 3], [3, 8, 0, 5, 6, 4, 1, 3, 9], [4, 7, 5, 3, 7, 6, 7, 2, 4]]

for i, (inp, target, pred_tgts, input_labels) in enumerate(loader):
    print(input_labels[0:2])

print('new epoch')
for i, (inp, target, pred_tgts, input_labels) in enumerate(loader):
    print(input_labels[0:2])


tensor([5, 5])
tensor([6, 2])
tensor([4, 5])
tensor([1, 5])
tensor([3, 3])
tensor([9, 4])
tensor([2, 3])
tensor([9, 8])
tensor([1, 0])
tensor([9, 5])
new epoch
tensor([2, 6])
tensor([8, 4])
tensor([3, 1])
tensor([2, 3])
tensor([7, 9])
tensor([2, 1])
tensor([4, 2])
tensor([0, 6])
tensor([7, 4])
tensor([8, 8])

In [19]:
a = torch.rand(3, 5)
print(a)

sums = a.sum(dim=1, keepdim=True)  # keepdim=True keeps the reduced dim, so sums has shape (3, 1)
print(sums)


tensor([[0.5388, 0.3831, 0.0448, 0.9141, 0.7032],
        [0.4118, 0.6417, 0.1732, 0.8667, 0.4823],
        [0.5891, 0.5699, 0.9462, 0.2595, 0.5103]])
tensor([[2.5839],
        [2.5757],
        [2.8750]])

In [39]:
cache = torch.zeros(3, 5)
index = torch.tensor([0, 1, 3])
# One-hot encode the first set of column indices.
cache.scatter_(1, index.unsqueeze(1), 1.0)

# Decay the existing entries...
cache = cache * 0.6

# ...then stamp the next set of indices at full strength.
cache.scatter_(1, torch.tensor([1, 2, 1]).unsqueeze(1), 1.0)
print(cache)

# Row-normalize so each row sums to 1.
print(cache / cache.sum(dim=1, keepdim=True))


tensor([[0.6000, 1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6000, 1.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.6000, 0.0000]])
tensor([[0.3750, 0.6250, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3750, 0.6250, 0.0000, 0.0000],
        [0.0000, 0.6250, 0.0000, 0.3750, 0.0000]])
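
The decay-then-scatter pattern above generalizes to a running, exponentially decaying one-hot history over many steps; a sketch (the decay rate and label stream below are illustrative):

decay = 0.6
trace = torch.zeros(3, 5)
for labels in [torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]), torch.tensor([4, 0, 2])]:
    trace = trace * decay                         # fade older entries
    trace.scatter_(1, labels.unsqueeze(1), 1.0)   # stamp the newest label at full strength
print(trace / trace.sum(dim=1, keepdim=True))     # row-normalize, as in the cell above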