In [1]:
import argparse
import os
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader

from dpp_nets.utils.io import make_embd, make_tensor_dataset
from dpp_nets.layers.layers import DeepSetBaseline

# Command-line interface for the Deep Sets baseline trainer.
# `main()` parses these options into the module-level `args` namespace that the
# helper functions below (train / validate / log / save_checkpoint) read.
parser = argparse.ArgumentParser(description='Baseline (Deep Sets) Trainer')

# Which target column(s) to regress on, and which set of paths to use.
parser.add_argument('-a', '--aspect', type=str, choices=['aspect1', 'aspect2', 'aspect3', 'all'],
                    help='what is the target?', required=True)
parser.add_argument('--remote', type=int,
                    help='training locally (0) or on the cluster (non-zero)?', required=True)

# Data locations (the one matching --remote is used).
parser.add_argument('--data_path_local', type=str, default='/Users/Max/data/beer_reviews',
                    help='where is the data folder locally?')
parser.add_argument('--data_path_remote', type=str, default='/cluster/home/paulusm/data/beer_reviews',
                    help='where is the data folder on the cluster?')

# Checkpoint / log locations (the one matching --remote is used).
parser.add_argument('--ckp_path_local', type=str, default='/Users/Max/checkpoints/beer_reviews',
                    help='where is the checkpoint folder locally?')

parser.add_argument('--ckp_path_remote', type=str, default='/cluster/home/paulusm/checkpoints/beer_reviews',
                    help='where is the checkpoint folder on the cluster?')

# Optimisation settings.
parser.add_argument('-b', '--batch-size', default=50, type=int,
                    metavar='N', help='mini-batch size (default: 50)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
# The commented-out options below are not defined for the baseline; any code
# reading args.lr_k / args.lr_p / args.reg / args.reg_mean will fail.
#parser.add_argument('--lr-k', '--learning-rate-k', default=0.1, type=float,
#                    metavar='LRk', help='initial learning rate for kernel net')
#parser.add_argument('--lr-p', '--learning-rate-p', default=0.1, type=float,
#                    metavar='LRp', help='initial learning rate for pred net')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                    metavar='LR', help='initial learning rate for baseline')
#parser.add_argument('--reg', type=float, required=True,
#                    metavar='reg', help='regularization constant')
#parser.add_argument('--reg-mean', type=float, required=True,
#                    metavar='reg_mean', help='regularization_mean')

def main(argv=None):
    """Train the Deep Sets baseline and write a checkpoint after every epoch.

    Args:
        argv: optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv``, as before; pass an explicit list
            when calling from a notebook, where ``sys.argv`` belongs to the kernel.

    Side effects:
        Sets the module globals ``args`` (read by train / validate / log /
        adjust_learning_rate / save_checkpoint) and ``lowest_loss``; appends to
        the log file and writes checkpoint files under the checkpoint folder.
    """
    global args, lowest_loss

    args = parser.parse_args(argv)
    lowest_loss = float('inf')  # no validation loss observed yet

    ### Load data
    # Only the root folder differs between the remote and the local set-up.
    data_path = args.data_path_remote if args.remote else args.data_path_local
    train_path = os.path.join(data_path, ".".join(['reviews', args.aspect, 'train.txt.gz']))
    val_path = os.path.join(data_path, ".".join(['reviews', args.aspect, 'heldout.txt.gz']))
    embd_path = os.path.join(data_path, 'review+wiki.filtered.200.txt.gz')

    embd, word_to_ix = make_embd(embd_path)
    train_set = make_tensor_dataset(train_path, word_to_ix)
    val_set = make_tensor_dataset(val_path, word_to_ix)
    print("loaded data")

    torch.manual_seed(0)
    train_loader = DataLoader(train_set, args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, args.batch_size)
    print("loader defined")

    ### Build model
    # Network parameters
    embd_dim = embd.weight.size(1)
    hidden_dim = 500
    enc_dim = 200
    if args.aspect == 'all':
        target_dim = 3
    else:
        target_dim = 1

    # Model
    torch.manual_seed(0)
    # NOTE(review): DeepSetBaseline's constructor is not visible in this file.
    # The argument order below is assumed from the four dimensions prepared
    # above, and the net is assumed to consume embedded words -- confirm
    # against dpp_nets.layers.layers before relying on this.
    net = DeepSetBaseline(embd_dim, hidden_dim, enc_dim, target_dim)
    model = nn.Sequential(embd, net, nn.Sigmoid())
    criterion = nn.MSELoss()
    print("created model")

    # Set-up Training
    # A single learning rate: the parser only defines --lr for the baseline.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    print('set-up optimizer')

    ### Loop
    torch.manual_seed(0)
    print("started loop")
    for epoch in range(args.epochs):

        adjust_learning_rate(optimizer, epoch)

        train(train_loader, model, criterion, optimizer)
        loss = validate(val_loader, model, criterion)

        log(epoch, loss)
        print("logged")

        is_best = loss < lowest_loss
        lowest_loss = min(loss, lowest_loss)
        # The key is 'epoch' (the previous 'epoch:' carried a stray colon).
        save = {'epoch': epoch + 1,
                'model': 'Deep Set Baseline',
                'state_dict': model.state_dict(),
                'lowest_loss': lowest_loss,
                'optimizer': optimizer.state_dict()}

        save_checkpoint(save, is_best)
        print("saved a checkpoint")

    print('*'*20, 'SUCCESS','*'*20)


def train(loader, model, criterion, optimizer):
    """Run one pass over ``loader``, taking one optimizer step per batch.

    Args:
        loader: iterable yielding ``(review, target)`` tensor pairs.
        model: callable module; switched to training mode first.
        criterion: loss callable taking ``(prediction, target)``.
        optimizer: optimizer whose ``zero_grad`` / ``step`` are called per batch.

    Reads the module-level ``args.aspect`` to choose the target columns:
    the first three columns for ``'all'``, otherwise the single column whose
    index is the trailing digit of the aspect name.
    """
    model.train()

    for batch_review, batch_target in loader:
        inputs = Variable(batch_review)

        if args.aspect != 'all':
            column = int(args.aspect[-1])
            labels = Variable(batch_target[:, column])
        else:
            labels = Variable(batch_target[:, :3])

        batch_loss = criterion(model(inputs), labels)

        # Clear stale gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        print("trained one batch")

def validate(loader, model, criterion):
    """Return the mean of the per-batch losses over ``loader``.

    Args:
        loader: iterable yielding ``(review, target)`` tensor pairs.
        model: callable module; switched to evaluation mode first.
        criterion: loss callable taking ``(prediction, target)``.

    Returns:
        The running average of the batch losses (``0.0`` for an empty loader).
        Every batch carries equal weight regardless of its size.

    Reads the module-level ``args.aspect`` to choose the target columns, in
    the same way as ``train``.
    """
    model.eval()

    running_mean = 0.0
    batch_no = 0

    for batch_review, batch_target in loader:
        batch_no += 1

        inputs = Variable(batch_review, volatile=True)

        if args.aspect != 'all':
            column = int(args.aspect[-1])
            labels = Variable(batch_target[:, column], volatile=True)
        else:
            labels = Variable(batch_target[:, :3], volatile=True)

        batch_loss = criterion(model(inputs), labels)

        # Incremental mean: move the average towards the new value by 1/n.
        running_mean += (batch_loss.data[0] - running_mean) / batch_no

        print("validated one batch")

    return running_mean

def log(epoch, loss):
    """Append one ``Epoch | Validation Loss`` line to the run's log file.

    The file lives in the checkpoint folder selected by ``args.remote`` and is
    named ``<aspect><lr>baseline_log.txt``; it is opened in append mode.

    Args:
        epoch: epoch number, formatted as an integer.
        loss: validation loss, formatted with five decimals.
    """
    fields = ['Epoch: %d' % (epoch), 'Validation Loss: %.5f' % (loss)]
    line = " | ".join(fields)

    ckp_root = args.ckp_path_remote if args.remote else args.ckp_path_local
    destination = os.path.join(ckp_root, args.aspect + str(args.lr) + 'baseline_log.txt')

    with open(destination, 'a') as handle:
        handle.write(line + '\n')

def adjust_learning_rate(optimizer, epoch, decay_every=25, decay_factor=0.1):
    """Apply a step-decay schedule to every parameter group of ``optimizer``.

    The learning rate is ``args.lr * decay_factor ** (epoch // decay_every)``,
    i.e. with the defaults the initial rate is multiplied by 0.1 every 25
    epochs (the previous docstring said 20, which did not match the code).

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: zero-based epoch index.
        decay_every: number of epochs between two decays (default 25).
        decay_factor: multiplicative factor applied at each decay (default 0.1).

    Note: every group gets the same rate, so per-group learning rates set at
    construction time are overwritten.
    """
    lr = args.lr * (decay_factor ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='baseline_checkpoint.pth.tar'):
    """Serialise ``state`` with ``torch.save`` and optionally keep a "best" copy.

    Args:
        state: dictionary with the information to be saved.
        is_best: when true, the file just written is also copied to
            ``<aspect><lr>baseline_model_best.pth.tar`` in the same folder.
        filename: suffix of the checkpoint file name.

    The target folder is ``args.ckp_path_remote`` or ``args.ckp_path_local``
    depending on ``args.remote``; file names are prefixed with the aspect and
    the learning rate.
    """
    ckp_root = args.ckp_path_remote if args.remote else args.ckp_path_local
    run_prefix = args.aspect + str(args.lr)

    destination = os.path.join(ckp_root, run_prefix + filename)
    torch.save(state, destination)

    if not is_best:
        return

    best_destination = os.path.join(ckp_root, run_prefix + 'baseline_model_best.pth.tar')
    shutil.copyfile(destination, best_destination)

In [2]:
# Build the global `args` namespace by hand (all aspects, local paths) so the
# cells below can run without going through main().
args = parser.parse_args(['-a', 'all', '--remote', '0'])

In [3]:
# Load the embedding and the held-out split from the local data folder.
embd_path = os.path.join(args.data_path_local, 'review+wiki.filtered.200.txt.gz')
val_path = os.path.join(args.data_path_local, ".".join(['reviews', args.aspect, 'heldout.txt.gz']))

embd, word_to_ix = make_embd(embd_path)
val_set = make_tensor_dataset(val_path, word_to_ix)

# Batches of up to 10000 examples.
val_loader = DataLoader(val_set, batch_size=10000)

In [4]:
# Load model
from dpp_nets.layers.layers import MarginalTrainer, ReinforceSampler

load_path = '/Users/Max/checkpoints/full_backups/no_batch_normalization_elu/'
model_name = 'allreg0.1reg_mean10.0marginal_checkpoint.pth.tar'
# torch.load unpickles the file -- only use it on checkpoints you trust.
a_dict = torch.load(os.path.join(load_path, model_name))
reinstated_model = MarginalTrainer(embd, 200,500,200,3)
#reinstated_model.load_state_dict(a_dict['state_dict'])
# Copy the saved tensors onto the fresh model one by one, skipping every entry
# whose name contains '.0.', '.2.' or '.4.'.
# NOTE(review): the skipped entries keep their freshly initialised values;
# confirm that is intended.
for param_name, params in a_dict['state_dict'].items():
    if '.0.' in param_name or '.2.' in param_name or '.4.' in param_name:
        continue
    # Walk the dotted name down to the owning sub-module and assign the leaf.
    # This does what the former exec('reinstated_model.<name> = ...') did,
    # without executing a string built from checkpoint contents.
    owner_path, _, attr_name = param_name.rpartition('.')
    owner = reinstated_model
    for part in (owner_path.split('.') if owner_path else []):
        owner = getattr(owner, part)
    setattr(owner, attr_name, nn.Parameter(params))
    print('once')
reinstated_model.activation = nn.Sigmoid()
reinstated_model.reg = 0.1
reinstated_model.reg_mean = 10
# Reverse vocabulary lookup; keys are shifted by one relative to word_to_ix
# (presumably because index 0 is reserved, e.g. for padding -- TODO confirm).
ix_to_word = {v + 1: k for k, v in word_to_ix.items()}


once
once
once
once
once
once
once
once
once
once
once
once
once
once
once
once
once
once
once

In [36]:
import random

# Pick a random held-out example. randrange excludes the upper bound and uses
# the real dataset size, so the index is always valid; the former
# randint(0, 10000) included 10000 and ignored the size of val_set.
ix = random.randrange(len(val_set))
review, target = val_set[ix]
sampler = ReinforceSampler(1)

In [42]:
# Draw one subset of words for the selected review and print it next to the
# full review text.

# Embed the review as a batch of one (unsqueeze adds the leading dimension);
# volatile=True marks the Variable as inference-only.
words = reinstated_model.embd(Variable(review.unsqueeze(0), volatile=True))
# kernel_net returns a pair; `words` is rebound to its second output.
kernel, words = reinstated_model.kernel_net(words)
# Give the sampler the slice boundaries stored on kernel_net.
sampler.s_ix = reinstated_model.kernel_net.s_ix
sampler.e_ix = reinstated_model.kernel_net.e_ix
# The return value is discarded; the draw is read back from saved_subsets below.
sampler(kernel, words) 

# presumably saved_subsets is indexed [example][sample] -- TODO confirm
my_sample = sampler.saved_subsets[0][0]
# Trim the review to the length of the sample (presumably dropping padding -- verify).
cut_review = review[:my_sample.size(0)]

# Word indices at the positions where the sample mask is non-zero.
my_list = list(cut_review.masked_select(my_sample.data.byte()))
original_review = [ix_to_word[i] for i in list(cut_review)]
print(" ".join(original_review))
print(50*'_')
for i in my_list:
    print(ix_to_word[i])


presentation : the brooklyn brewery style logo label with some waffling on about how it is their verison of an ipa . appearance : golden-amber , with a thin white head . smell : loads of spicy mint hop aromas , with a casual sweetness on the backend . taste : their hop use is applied more for flavour , rather than bittering . a very intruiging hop flavour to be found , and an over puckering reaction that runs dry at the end . very leafy , literally . notes : a bit overkill for this style . makes you wonder what they think an ipa is . seems like an over use of lower alpha hops .
__________________________________________________
is
with
a
thin
head
spicy
casual
ipa
seems
.

In [48]:
# Inspect the slice start indices copied onto the sampler from kernel_net.
sampler.s_ix


Out[48]:
[0]

In [41]:
from dpp_nets.my_torch.linalg import custom_decomp, custom_inverse 

# compute marginals
# For each (start, end) slice, form L = V V^T and the marginal kernel
# K = I - (L + I)^{-1}, then keep only K's diagonal as a diagonal matrix.
# NOTE(review): `marginals` is overwritten on every iteration, so only the
# last slice is printed; the recorded run had a single slice ([0], [10]).
for i, (s, e) in enumerate(zip(sampler.s_ix, sampler.e_ix)):
    # unpack kernel and words
    V = kernel[s:e]
    word = words[s:e]
    # compute marginal kernel K
    #vals, vecs = custom_decomp()(V)
    #K = vecs.mm((1 / (vals + 1)).diag()).mm(vecs.t()) # actually K = (identity - expression) 
    #marginals = (1 - K.diag()).diag() ## need to rewrite custom_decomp to return full svd + correct gradients. 
    # so this is the inefficient way
    # Identity sized by the number of rows in this slice, cast to the tensor type of `words`.
    identity = Variable(torch.eye(word.size(0)).type(words.data.type()))
    L = V.mm(V.t())
    # custom_inverse is presumably a matrix inverse with autograd support -- TODO confirm
    K = identity - custom_inverse()(L + identity)
    # diag() twice: extract the diagonal, then rebuild a diagonal matrix from it.
    marginals = (K.diag()).diag()
print(marginals)


Variable containing:
 0.9157  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 0.0000  0.6111  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000  0.8441  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000  0.9733  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000  0.0000  0.9807  0.0000  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000  0.0000  0.0000  0.9815  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.9847  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.9817  0.0000  0.0000
 0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.9707  0.0000
 0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.9798
[torch.FloatTensor of size 10x10]


In [44]:
# Inspect the slice boundaries that the marginals loop zips over.
sampler.s_ix, sampler.e_ix


Out[44]:
([0], [10])

In [ ]:
import matplotlib.pyplot as plt
import numpy as np  # was missing: np is not imported anywhere else in the notebook

# Hard-coded results: MSE (y) against expected set size (x).
x = np.array([1,2,3,4,5,10,15,20])
y = np.array([0.02244, 0.02219,0.02136,0.0206,0.02042,0.02029,0.0203,0.02034])
plt.plot(x,y,marker='o')
plt.xlabel('Expected Set Size (DPP)')
plt.ylabel('MSE')
plt.title('Sparsity vs Loss (All Aspects)')
# Save before show(), so the figure is written while it still exists.
plt.savefig('Reg.pdf')
plt.show()

In [ ]:
# Top-level keys stored in the loaded checkpoint.
list(a_dict)

In [ ]:
# Check which container type the checkpoint uses for its 'state_dict' entry.
type(a_dict['state_dict'])

In [ ]:
from SALib.sample import sobol_sequence

In [ ]:
# Draw a Sobol sequence with SALib.
# NOTE(review): presumably the arguments are (number of points, number of
# dimensions), i.e. a 10000 x 20000 array -- confirm the order, that SALib
# supports this many dimensions, and that the result fits in memory.
# The result is not used by any other visible cell.
a = sobol_sequence.sample(10000,20000)

In [ ]: