In [30]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from dpp_nets.layers.layers import KernelVar
In [23]:
# Toy batch: 2 sets of word embeddings (embd_dim = 4), zero-padded to max_set_size = 3.
# The first set has 2 real words, the second has 1.
words = Variable(torch.FloatTensor([[[1, 2, 3, 4], [3, 4, 5, 6], [0, 0, 0, 0]],
                                    [[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]]]))
In [26]:
batch_size, max_set_size, embd_dim = words.size()

# Set sizes: sum(2).abs().sign() marks each real (non-zero) word with a 1,
# so summing over the word dimension counts the real words per set.
lengths = words.sum(2).abs().sign().sum(1)

# Context: the mean of each set's real word embeddings, broadcast to every
# word slot (padding rows are all-zero, so they don't affect the sum).
context = (words.sum(1) / lengths.expand_as(words.sum(1))).expand_as(words)

# Filter out the zero-padded words, flattening the batch to (n_words, embd_dim).
mask = words.data.sum(2).abs().sign().expand_as(words).byte()
words = words.masked_select(Variable(mask)).view(-1, embd_dim)
context = context.masked_select(Variable(mask)).view(-1, embd_dim)

# Concatenate each word with its set's context to form the network input.
batch_x = torch.cat([words, context], dim=1)
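For the toy batch, three words survive the mask (two from the first set, one from the second), so words and context should both be (3, 4) and batch_x should be (3, 8). A quick shape check (an illustrative cell, not part of the original run):
In [ ]:
# Sanity check on the flattened, mask-filtered batch (toy input above).
print(words.size())    # torch.Size([3, 4])
print(context.size())  # torch.Size([3, 4])
print(batch_x.size())  # torch.Size([3, 8])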
In [35]:
hidden_dim = 200
kernel_dim = 200
torch.manual_seed(200)

# 3-layer tanh MLP mapping each (word, context) pair to a kernel embedding.
layer1 = nn.Linear(2 * embd_dim, hidden_dim)
layer2 = nn.Linear(hidden_dim, hidden_dim)
layer3 = nn.Linear(hidden_dim, kernel_dim)
net = nn.Sequential(layer1, nn.Tanh(), layer2, nn.Tanh(), layer3)

# One kernel embedding per real word in the whole batch: (n_words, kernel_dim).
batch_kernel = net(batch_x)
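Since batch_x is (3, 8), batch_kernel comes out as (3, 200): one row per surviving word, with words from different sets still stacked together. A quick check (illustrative, not part of the original run):
In [ ]:
print(batch_kernel.size())  # torch.Size([3, 200])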
In [37]:
# Start/end indices of each set's rows in the flattened batch_kernel:
# rows s_ix[i]:e_ix[i] belong to set i.
s_ix = list(lengths.squeeze().cumsum(0).long().data - lengths.squeeze().long().data)
e_ix = list(lengths.squeeze().cumsum(0).long().data)
In [38]:
s_ix
Out[38]:
[0, 2]
In [39]:
e_ix
Out[39]:
[2, 3]
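With these indices, batch_kernel can be sliced back into per-set embedding matrices, and each set's DPP kernel formed as L = V V^T. A sketch of that use (an illustration of the idea, not the actual dpp_nets KernelVar implementation):
In [ ]:
# Sketch: recover each set's embeddings and build its DPP kernel L = V V^T.
kernels = []
for s, e in zip(s_ix, e_ix):
    V = batch_kernel[s:e]            # (set_size, kernel_dim)
    kernels.append(V.mm(V.t()))      # (set_size, set_size), symmetric PSD
print([k.size() for k in kernels])   # [torch.Size([2, 2]), torch.Size([1, 1])]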