In [30]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from dpp_nets.layers.layers import KernelVar

In [23]:
# Toy padded batch: 2 sets x 3 word slots x 4-dim embeddings; all-zero rows are padding.
words = Variable(torch.FloatTensor([
    [[1, 2, 3, 4], [3, 4, 5, 6], [0, 0, 0, 0]],  # set 0: two real words, one pad slot
    [[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]],  # set 1: one real word, two pad slots
]))

In [26]:
batch_size, max_set_size, embd_dim = words.size()

# A word slot is "real" if ANY coordinate is non-zero; all-zero rows are padding.
# abs() must come before sum(): summing first would misclassify a real embedding whose
# coordinates cancel out (e.g. [1, -1, 0, 0]) as padding.
# The explicit .view() reshapes make this cell independent of whether sum(dim) keeps the
# reduced dimension (it did in PyTorch < 0.2, it no longer does).
is_word = words.abs().sum(2).view(batch_size, max_set_size, 1).sign()  # 1.0 real, 0.0 pad

# Create context: mean of the real word embeddings in each set, repeated for every slot.
lengths = is_word.sum(1).view(batch_size, 1, 1)  # number of real words per set
set_sums = words.sum(1).view(batch_size, 1, embd_dim)  # padding rows contribute zeros
# clamp only the divisor: an all-padding set would otherwise give 0/0 = NaN (its rows are
# filtered out below anyway); `lengths` itself keeps the true counts for the offsets.
context = (set_sums / lengths.clamp(min=1).expand_as(set_sums)).expand_as(words)

# Filter out zero words. New names keep the padded `words` intact so this cell can be re-run.
# A comparison op yields a ByteTensor on old PyTorch and a BoolTensor on new PyTorch,
# which is the mask type masked_select expects on each.
mask = is_word.data.gt(0).expand_as(words)
words_flat = words.masked_select(Variable(mask)).view(-1, embd_dim)
context_flat = context.masked_select(Variable(mask)).view(-1, embd_dim)

# Concatenate each real word with its set's context -> rows of width 2 * embd_dim.
batch_x = torch.cat([words_flat, context_flat], dim=1)

In [35]:
# Kernel network: each row of batch_x (width 2 * embd_dim) passes through two tanh
# hidden layers and is projected to a kernel_dim-dimensional embedding.
hidden_dim, kernel_dim = 200, 200

# Seed right before the layers are built so their random initialisation is reproducible;
# the tuple below is evaluated left to right, so layers are created in order 1, 2, 3.
torch.manual_seed(200)
layer1, layer2, layer3 = (
    nn.Linear(2 * embd_dim, hidden_dim),
    nn.Linear(hidden_dim, hidden_dim),
    nn.Linear(hidden_dim, kernel_dim),
)

net = nn.Sequential(
    layer1, nn.Tanh(),
    layer2, nn.Tanh(),
    layer3,
)
batch_kernel = net(batch_x)

In [37]:
# Register indices for individual kernels: set i owns rows s_ix[i]:e_ix[i] of the
# flattened batch (e_ix is exclusive).
# view(-1) instead of squeeze() keeps a 1-D tensor even when batch_size == 1, and
# tolist() returns plain Python ints on both old and new PyTorch.
set_sizes = lengths.data.contiguous().view(-1).long()
e_ix = set_sizes.cumsum(0).tolist()
s_ix = (set_sizes.cumsum(0) - set_sizes).tolist()

In [38]:
# Start row (inclusive) of each set's words in the flattened batch.
s_ix


Out[38]:
[0, 2]

In [39]:
# End row (exclusive) of each set's words in the flattened batch.
e_ix


Out[39]:
[2, 3]

In [ ]: