In [41]:
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd import StochasticFunction
from torch.autograd import Variable
import dpp_nets.dpp as dpp
from dpp_nets.my_torch.utilities import omit_slice
from dpp_nets.my_torch.utilities import orthogonalize
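# AllInOne draws an exact sample from the DPP defined by an embedding
# ("kernel") matrix B, where L = B B^T, and implements the REINFORCE
# (score-function) gradient of log p(subset) in backward().
# Note that orthogonalize is imported but unused below; torch.qr is used instead.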
class AllInOne(StochasticFunction):

    def forward(self, kernel):
        self.dtype = kernel.type()

        # The left singular vectors of B are the eigenvectors of L = B B^T,
        # and the squared singular values are its eigenvalues.
        vecs, vals, _ = torch.svd(kernel)
        vals.pow_(2)

        # Sometimes the orthogonalization fails (i.e. deletes vectors).
        # In that case, just retry.
        while True:
            try:
                # Set-up
                n = vecs.size(0)
                n_vals = vals.size(0)

                # Sample a set size: eigenvector i is selected with
                # probability vals[i] / (vals[i] + 1)
                index = (vals / (vals + 1)).bernoulli().byte()
                k = torch.sum(index)

                # Check for the empty set
                if not k:
                    subset = vals.new(n).zero_()
                    self.save_for_backward(kernel, subset)
                    return subset

                # Check for the full set
                if k == n:
                    subset = vals.new(n).fill_(1)
                    self.save_for_backward(kernel, subset)
                    return subset
                # Sample a subset from the elementary DPP spanned by the
                # selected eigenvectors
                V = vecs[index.expand_as(vecs)].view(n, -1)
                subset = vals.new(n).zero_()

                while subset.sum() < k:
                    # Sample an item: P(item i) is proportional to the
                    # squared norm of the i-th row of V
                    probs = V.pow(2).sum(1).t()
                    item = probs.multinomial(1)[0, 0]
                    subset[item] = 1

                    # Check if we got k items now
                    if subset.sum() == k:
                        break

                    # Choose an eigenvector to eliminate (uniformly among
                    # those with a non-zero coordinate at the sampled item)
                    j = V[item, :].abs().sign().unsqueeze(1).t().multinomial(1)[0, 0]
                    Vj = V[:, j]

                    # Update the vector basis so that no remaining direction
                    # can re-select the sampled item
                    V = omit_slice(V, 1, j)
                    V.sub_(Vj.ger(V[item, :] / Vj[item]))

                    # Orthogonalize the vector basis
                    V, _ = torch.qr(V)

            except RuntimeError:
                print("RuntimeError")
                continue
            break

        self.save_for_backward(kernel, subset)
        return subset
    def backward(self, reward):
        # REINFORCE: return reward * d log p(S) / d kernel, where
        # log p(S) = log det(L_S) - log det(L + I) and L = B B^T.

        # Set-up
        kernel, subset = self.saved_tensors
        dtype = self.dtype
        n, kernel_dim = kernel.size()
        subset_sum = subset.long().sum()
        grad_kernel = torch.zeros(kernel.size()).type(dtype)

        if subset_sum:
            # Gradient from the sampled submatrix L_S:
            # P is a (|S| x n) selection matrix picking the sampled rows
            P = torch.eye(n).masked_select(subset.expand(n, n).t().byte()).view(subset_sum, -1).type(dtype)
            subembd = P.mm(kernel)
            submatrix = subembd.mm(subembd.t())
            submatinv = torch.inverse(submatrix)
            subgrad = 2 * submatinv.mm(subembd)  # d log det(L_S) / d B_S
            subgrad = P.t().mm(subgrad)          # scatter back to full size
            grad_kernel.add_(subgrad)

        # Gradient from the normalizer log det(L + I)
        K = kernel.t().mm(kernel)  # the (d x d) Gram matrix, not L!
        I_k = torch.eye(kernel_dim).type(dtype)
        I = torch.eye(n).type(dtype)
        inv = torch.inverse(I_k + K)
        B = I - kernel.mm(inv).mm(kernel.t())  # equals (I + L)^{-1} by Woodbury
        grad_from_full = 2 * B.mm(kernel)      # d log det(L + I) / d B
        grad_kernel.sub_(grad_from_full)

        # Scale the score-function gradient by the reward
        grad_kernel.mul_(reward)
        return grad_kernel
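The helper omit_slice lives in dpp_nets.my_torch.utilities and is not shown in this notebook. From its call site above it appears to drop slice j along dimension 1; a minimal sketch of such a helper (an assumption about its behavior, not the package's actual implementation):
In [ ]:
import torch

def omit_slice(tensor, dim, index):
    # Hypothetical stand-in for dpp_nets.my_torch.utilities.omit_slice:
    # keep every slice along `dim` except the one at `index`.
    keep = [i for i in range(tensor.size(dim)) if i != index]
    return tensor.index_select(dim, torch.LongTensor(keep))

print(omit_slice(torch.eye(3), 1, 1).size())  # torch.Size([3, 2])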
In [42]:
# Smoke test: repeatedly sample from random 10 x 10 embeddings to make
# sure forward() always terminates (the retry loop absorbs QR failures).
for i in range(10000):
    A = Variable(torch.randn(10, 10), requires_grad=True)
    AllInOne()(A)
In [15]:
print(i)  # 9999 confirms all iterations of the smoke test completed
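Because AllInOne is a StochasticFunction, gradients flow through the old (pre-0.4) stochastic autograd API: the sampled Variable is fed a reward via .reinforce() and receives None in place of a gradient. A minimal sketch, assuming that API and using the subset size as a placeholder reward:
In [ ]:
import torch.autograd as autograd

A = Variable(torch.randn(10, 10), requires_grad=True)
subset = AllInOne()(A)                  # 0/1 indicator vector of length 10
subset.reinforce(subset.sum().data[0])  # placeholder scalar reward: |S|
autograd.backward([subset], [None])     # stochastic outputs take no gradient
print(A.grad.size())                    # (10, 10) gradient w.r.t. the embedding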
In [ ]: