In [1]:
from dpp_nets.layers.layers import *
from dpp_nets.my_torch.utilities import pad_tensor

In [2]:
embd_dim = 10
hidden_dim = 20
kernel_dim = 15
enc_dim = 50
target_dim = 3
alpha_iter = 5

kernel_net = KernelVar(embd_dim, hidden_dim, kernel_dim).double()
#sampler = MarginalSampler()
sampler = ReinforceSampler(alpha_iter)
pred_net = PredNet(embd_dim, hidden_dim, enc_dim, target_dim).double()

In [8]:
batch_size = 4
max_set_size = 7
data = torch.cat([pad_tensor(torch.randn(1,6,embd_dim),1,0,max_set_size),
           pad_tensor(torch.randn(1,4,embd_dim),1,0,max_set_size), 
           pad_tensor(torch.randn(1,7,embd_dim),1,0,max_set_size),
        pad_tensor(torch.randn(1,5,embd_dim),1,0,max_set_size)])

words = Variable(data).double()
target = Variable(torch.randn(batch_size, target_dim)).double()
kernel, words = kernel_net(words)
sampler.s_ix = kernel_net.s_ix
sampler.e_ix = kernel_net.e_ix
weighted_words = sampler(kernel, words)
pred_net.s_ix = sampler.s_ix
pred_net.e_ix = sampler.e_ix
pred_net(weighted_words)

# trainer = MarginalTrainer(kernel_net, sampler, pred_net)
trainer = ReinforceTrainer(kernel_net, sampler, pred_net)
trainer.reg = 0
trainer.reg_mean = 3
trainer.activation = nn.Sigmoid()
words = Variable(data).double()
trainer(words, target)


Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Zero Subset was produced. Re-sample
Out[8]:
Variable containing:
 1.0636
[torch.DoubleTensor of size 1]

In [9]:
trainer.parameters


Out[9]:
<bound method Module.parameters of ReinforceTrainer (
  (kernel_net): KernelVar (
    (layer1): Linear (20 -> 20)
    (layer2): Linear (20 -> 20)
    (layer3): Linear (20 -> 15)
    (net): Sequential (
      (0): Linear (20 -> 20)
      (1): ELU (alpha=1.0)
      (2): Linear (20 -> 20)
      (3): ELU (alpha=1.0)
      (4): Linear (20 -> 15)
    )
  )
  (sampler): ReinforceSampler (
  )
  (pred_net): PredNet (
    (enc_layer1): Linear (10 -> 20)
    (enc_layer2): Linear (20 -> 20)
    (enc_layer3): Linear (20 -> 50)
    (enc_net): Sequential (
      (0): Linear (10 -> 20)
      (1): ReLU ()
      (2): Linear (20 -> 20)
      (3): ReLU ()
      (4): Linear (20 -> 50)
    )
    (pred_layer1): Linear (50 -> 20)
    (pred_layer2): Linear (20 -> 20)
    (pred_layer3): Linear (20 -> 3)
    (pred_net): Sequential (
      (0): Linear (50 -> 20)
      (1): ReLU ()
      (2): Linear (20 -> 20)
      (3): ReLU ()
      (4): Linear (20 -> 3)
    )
  )
  (criterion): MSELoss (
  )
  (activation): Sigmoid ()
)>

In [ ]:
t