In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

from dpp_nets.layers.layers import *
from dpp_nets.my_torch.utilities import pad_tensor

In [2]:
# Dimensions shared by the networks and the toy data built in the next cell.
embd_dim = 10      # feature size of each input element (last dim of the toy data)
hidden_dim = 20    # passed to both KernelVar and PredNet
kernel_dim = 15    # passed to KernelVar
enc_dim = 50       # passed to PredNet
target_dim = 3     # width of each target / prediction row
alpha_iter = 5     # passed to ReinforceSampler; presumably the number of subsets sampled per set — TODO confirm

# Seed torch's RNG before building the networks so that weight initialisation and the
# random toy data are reproducible under Restart & Run All (and the sampled subsets too,
# assuming the sampler draws from torch's RNG — confirm).
SEED = 0
torch.manual_seed(SEED)

# Networks are cast to float64 to match the `.double()` inputs used later.
kernel_net = KernelVar(embd_dim, hidden_dim, kernel_dim).double()
# Alternative: `sampler = MarginalSampler()`, paired with `MarginalTrainer` in the training cell.
sampler = ReinforceSampler(alpha_iter)
pred_net = PredNet(embd_dim, hidden_dim, enc_dim, target_dim).double()

In [8]:
# Toy batch of variable-size sets, padded to a common length along dim 1.
set_sizes = [6, 4, 7, 5]          # number of real elements in each set
batch_size = len(set_sizes)       # derived from the data so the two cannot drift apart
max_set_size = 7
assert max(set_sizes) <= max_set_size, "a set is larger than max_set_size"

# NOTE(review): pad_tensor(t, 1, 0, max_set_size) presumably zero-pads dim 1 up to
# max_set_size — confirm against dpp_nets.my_torch.utilities.pad_tensor.
data = torch.cat([pad_tensor(torch.randn(1, size, embd_dim), 1, 0, max_set_size)
                  for size in set_sizes])

words = Variable(data).double()
target = Variable(torch.randn(batch_size, target_dim)).double()

# Manual forward pass through the three stages as a smoke test before training.
# kernel_net's second output gets its own name so `words` keeps holding the raw input.
kernel, kernel_words = kernel_net(words)
# Hand the s_ix / e_ix bookkeeping from kernel_net to the sampler, then to pred_net
# (presumably start/end indices of each set — confirm).
sampler.s_ix = kernel_net.s_ix
sampler.e_ix = kernel_net.e_ix
weighted_words = sampler(kernel, kernel_words)
pred_net.s_ix = sampler.s_ix
pred_net.e_ix = sampler.e_ix
# Keep the result so it can be inspected instead of being silently discarded.
smoke_pred = pred_net(weighted_words)

# Wrap the raw padded batch afresh so the trainer receives the original inputs,
# regardless of how `words` may have been rebound earlier in this cell.
words = Variable(data).double()

# Alternative: `MarginalTrainer(kernel_net, sampler, pred_net)`, paired with `MarginalSampler`.
trainer = ReinforceTrainer(kernel_net, sampler, pred_net)
# Trainer options are configured as attributes after construction.
# NOTE(review): with reg = 0, reg_mean presumably has no effect — confirm.
trainer.reg, trainer.reg_mean = 0, 3
trainer.activation = nn.Sigmoid()

# Last expression: the value returned by the trainer call is shown as the cell output.
trainer(words, target)


Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Out[8]:
Variable containing:
 1.0636
[torch.DoubleTensor of size 1]

In [7]:
# Show the trainer's `pred` attribute (presumably populated by the `trainer(words, target)`
# call — run that cell first, otherwise this reflects stale or missing state).
getattr(trainer, "pred")


Out[7]:
Variable containing:
-0.1240  0.1900 -0.0065
-0.1560  0.1945 -0.0209
-0.1558  0.2080  0.0148
-0.1558  0.2080  0.0148
-0.1401  0.2110  0.0208
-0.1356  0.2060  0.0075
-0.1330  0.1997  0.0182
-0.1284  0.1896 -0.0207
-0.1244  0.1857 -0.0160
-0.1284  0.1896 -0.0207
-0.1291  0.1926  0.0083
-0.1362  0.1993  0.0139
-0.1362  0.1993  0.0139
-0.0981  0.1791 -0.0072
-0.1291  0.1926  0.0083
-0.1217  0.2021  0.0020
-0.1388  0.2022  0.0183
-0.1217  0.2021  0.0020
-0.1250  0.1836 -0.0084
-0.1236  0.1975 -0.0085
[torch.DoubleTensor of size 20x3]

In [ ]: