In [41]:
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd import StochasticFunction
from torch.autograd import Variable

import dpp_nets.dpp as dpp
from dpp_nets.my_torch.utilities import omit_slice
from dpp_nets.my_torch.utilities import orthogonalize


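# AllInOne bundles DPP sampling and its gradient into a single stochastic
# autograd node (legacy pre-0.4 StochasticFunction API): forward samples a
# subset from the DPP defined by L = kernel * kernel^T with the standard
# eigendecomposition-based sampler (cf. Kulesza & Taskar), and backward
# returns the score-function (REINFORCE) gradient of log P(subset),
# scaled by a reward supplied from outside.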
class AllInOne(StochasticFunction):
    
    def forward(self, kernel):
        self.dtype = kernel.type()
        vecs, vals, _ = torch.svd(kernel)
        vals.pow_(2)  # squared singular values = eigenvalues of L = kernel * kernel^T

        # Occasionally the re-orthogonalization fails (i.e. drops vectors),
        # which surfaces as a RuntimeError below. In that case, just retry.
        while True:
            try:
                # Set-up
                n = vecs.size(0)
                n_vals = vals.size(0)

                # Sample the set size: eigenvector i is kept with
                # probability vals[i] / (vals[i] + 1)
                index = (vals / (vals + 1)).bernoulli().byte()
                k = torch.sum(index)

                # Check for empty set
                if not k:
                    subset = vals.new().resize_(n).copy_(torch.zeros(n))
                    self.save_for_backward(kernel, subset) 
                    return subset
                
                # Check for full set
                if k == n:
                    subset =  vals.new().resize_(n).copy_(torch.ones(n))
                    self.save_for_backward(kernel, subset) 
                    return subset

                # Sample a subset from the elementary DPP spanned by the
                # selected eigenvectors
                V = vecs[index.expand_as(vecs)].view(n, -1)
                subset = vals.new().resize_(n).copy_(torch.zeros(n))
                
                while subset.sum() < k:

                    # Sample an item: p(i) is proportional to the squared
                    # norm of row i of V
                    probs = V.pow(2).sum(1).t()
                    item = probs.multinomial(1)[0, 0]
                    subset[item] = 1

                    # Check if we got k items now
                    if subset.sum() == k:
                        break

                    # Choose an eigenvector to eliminate (uniformly among
                    # those with a non-zero entry at item)
                    j = V[item, :].abs().sign().unsqueeze(1).t().multinomial(1)[0, 0]
                    Vj = V[:, j]

                    # Update the basis: remove column j and project out the
                    # direction that assigns probability to item
                    V = omit_slice(V, 1, j)
                    V.sub_(Vj.ger(V[item, :] / Vj[item]))

                    # Re-orthogonalize the basis
                    V, _ = torch.qr(V)

            except RuntimeError:
                # QR dropped a vector; restart the sampling pass
                print("RuntimeError during sampling, retrying")
                continue
            break
        
        # Stash tensors for backward
        self.kernel, self.subset = kernel, subset
        return subset
        
    def backward(self, reward):
        # Score-function (REINFORCE) gradient:
        # reward * d log P(subset) / d kernel

        # Set-up
        kernel, subset = self.kernel, self.subset
        dtype = self.dtype

        n, kernel_dim = kernel.size()
        subset_sum = subset.long().sum()   
        grad_kernel = torch.zeros(kernel.size()).type(dtype)

        if subset_sum:
            # P selects the rows of the embedding that are in the subset
            P = torch.eye(n).masked_select(subset.expand(n, n).t().byte()).view(subset_sum, -1).type(dtype)
            subembd = P.mm(kernel)
            submatrix = subembd.mm(subembd.t())
            submatinv = torch.inverse(submatrix)
            # Gradient of log det(subembd * subembd^T), mapped back to full size
            subgrad = 2 * submatinv.mm(subembd)
            subgrad = P.t().mm(subgrad)
            grad_kernel.add_(subgrad)
        
        # Gradient of log det(I + L) with L = kernel * kernel^T
        K = kernel.t().mm(kernel)  # kernel_dim x kernel_dim Gram matrix, not L!
        I_k = torch.eye(kernel_dim).type(dtype)
        I = torch.eye(n).type(dtype)
        inv = torch.inverse(I_k + K)
        B = I - kernel.mm(inv).mm(kernel.t())  # equals (I + L)^{-1} by Woodbury
        grad_from_full = 2 * B.mm(kernel)
        grad_kernel.sub_(grad_from_full)

        grad_kernel.mul_(reward)

        return grad_kernel
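
For reference, a sketch of the math that backward implements, assuming reward is the scalar reinforcement signal handed in from outside: with embedding $B$ (the kernel), DPP kernel $L = BB^\top$, and sampled subset $S$,

$$\log P(S) = \log\det(B_S B_S^\top) - \log\det(I + BB^\top),$$

so the returned score-function gradient is

$$r \, \nabla_B \log P(S) = r \left( 2\,P^\top (B_S B_S^\top)^{-1} B_S - 2\,(I + BB^\top)^{-1} B \right),$$

where $P$ selects the rows indexed by $S$. The code computes $(I + BB^\top)^{-1}$ as $I - B(I_k + B^\top B)^{-1} B^\top$ (Woodbury), so only a kernel_dim x kernel_dim matrix has to be inverted.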

In [42]:
# Smoke test: repeatedly sample from random kernels to exercise the retry logic
for i in range(10000):
    A = Variable(torch.randn(10, 10), requires_grad=True)
    AllInOne()(A)
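
A reward has to be injected for backward to fire. A minimal sketch, assuming the legacy pre-0.4 stochastic autograd API (the constant reward below is a hypothetical stand-in for a task-specific signal):

In [ ]:
A = Variable(torch.randn(10, 10), requires_grad=True)
subset = AllInOne()(A)        # stochastic node: a 0/1 indicator vector
subset.reinforce(1.0)         # hypothetical reward, routed to AllInOne.backward
torch.autograd.backward([subset], [None])
print(A.grad)                 # holds reward * d log P(subset) / dA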

In [15]:
print(i)


64
