In [1]:
import torch
from torch import nn
from torch.autograd import Variable

from dpp_nets.layers.layers import *
from dpp_nets.my_torch.utilities import pad_tensor
In [2]:
# Configuration: seed plus hyper-parameters for the toy kernel -> sampler -> predictor pipeline.
SEED = 0
# Seed before building the nets so weight init and the random toy data below are reproducible.
torch.manual_seed(SEED)

embd_dim = 10     # feature size of each set element (last dim of the toy data)
hidden_dim = 20   # passed to both KernelVar and PredNet
kernel_dim = 15   # passed to KernelVar
enc_dim = 50      # passed to PredNet
target_dim = 3    # width of the regression target
alpha_iter = 5    # passed to ReinforceSampler; TODO(review): document its meaning

kernel_net = KernelVar(embd_dim, hidden_dim, kernel_dim).double()
# Alternative sampler: MarginalSampler() (paired with MarginalTrainer instead of ReinforceTrainer).
sampler = ReinforceSampler(alpha_iter)
pred_net = PredNet(embd_dim, hidden_dim, enc_dim, target_dim).double()
In [8]:
# --- Toy batch -----------------------------------------------------------------
# Each entry is the number of elements in one set. Every set is a random
# (1, n, embd_dim) tensor padded along dim 1 to max_set_size; the sets are then
# stacked along dim 0. Deriving batch_size / max_set_size from this list keeps
# them consistent with the data.
set_sizes = [6, 4, 7, 5]
batch_size = len(set_sizes)
max_set_size = max(set_sizes)
data = torch.cat([pad_tensor(torch.randn(1, n, embd_dim), 1, 0, max_set_size)
                  for n in set_sizes])
# NOTE(review): assumes pad_tensor(t, 1, 0, L) pads dim 1 of t to length L — confirm in dpp_nets.
assert tuple(data.size()) == (batch_size, max_set_size, embd_dim)
target = Variable(torch.randn(batch_size, target_dim)).double()

# --- Manual forward pass: kernel_net -> sampler -> pred_net ----------------------
# s_ix / e_ix are copied from kernel_net onto the sampler, and from the sampler
# onto pred_net, before each downstream call.
kernel, kernel_words = kernel_net(Variable(data).double())
sampler.s_ix = kernel_net.s_ix
sampler.e_ix = kernel_net.e_ix
weighted_words = sampler(kernel, kernel_words)
pred_net.s_ix = sampler.s_ix
pred_net.e_ix = sampler.e_ix
# Bound to a name so it can be inspected; a bare mid-cell expression is discarded.
pred = pred_net(weighted_words)

# --- End-to-end call through the trainer -----------------------------------------
# Alternative: MarginalTrainer(kernel_net, sampler, pred_net) together with MarginalSampler().
trainer = ReinforceTrainer(kernel_net, sampler, pred_net)
trainer.reg = 0        # TODO(review): reg / reg_mean are interpreted inside dpp_nets — document them
trainer.reg_mean = 3
trainer.activation = nn.Sigmoid()
# Fresh input Variable for the trainer; the manual pass above used its own copy.
words = Variable(data).double()
trainer(words, target)
Out[8]:
In [9]:
# Show the size of each trainable parameter (the bare attribute `trainer.parameters`
# only displays the method object, not the parameters themselves).
[p.size() for p in trainer.parameters()]
Out[9]:
In [ ]:
# TODO: empty scratch cell — add the next experiment here or delete the cell.