Create a general KLDCriterion (KL divergence between two diagonal Gaussian distributions) and verify it against the baseline BKLDCriterion.

Verify forward pass


In [1]:
-- Load the two KLD criterion implementations under test.
require 'KLDCriterion'
require 'BKLDCriterion'
-- kld: general KLD between two diagonal Gaussians ({prior}, {target}).
-- bkld: baseline KLD against a fixed standard-normal prior N(0, I).
kld = nn.KLDCriterion()
bkld = nn.BKLDCriterion()

In [2]:
-- Random "target" distribution: mean mu and log-variance lv.
-- pow(2):log() of a standard normal guarantees lv is the log of a
-- strictly positive variance.
mu = torch.randn(3,4)
lv = torch.randn(3,4):pow(2):log()
-- Prior fixed to the standard normal N(0, I), matching BKLDCriterion's
-- built-in assumption, so the two criteria should agree.
pmu = torch.zeros(3,4)
plv = torch.zeros(3,4)

In [3]:
-- Both criteria should print the same value when the prior is N(0, I)
-- (confirmed by Out[3]: both lines are 10.982358832999).
print(bkld:forward(mu, lv))
print(kld:forward({pmu, plv}, {mu, lv})) -- make sure you use the right one as "target"!


Out[3]:
10.982358832999	
10.982358832999	

Verify backward pass


In [4]:
-- Fresh random values for the backward check; the prior is now an
-- arbitrary diagonal Gaussian rather than N(0, I).
mu = torch.randn(3,4)
lv = torch.randn(3,4):pow(2):log()
pmu = torch.randn(3,4)
plv = torch.randn(3,4):pow(2):log()

In [5]:
dpmu, dplv, dmu, dlv = unpack(kld:backward({pmu, plv}, {mu, lv}))

In [6]:
h = 1e-5

In [7]:
-- Central-difference gradient check for dmu (gradient w.r.t. the target
-- mean). Each element is perturbed by +/-h; (f(x+h) - f(x-h)) / (2h) is
-- compared against the analytic gradient. Printed differences should be
-- tiny (~O(h^2)), as Out[7] confirms (all ~1e-8).
-- Fix: fph/fmp were accidental globals; declared local here.
for i = 1,mu:size(1) do
    for j = 1,mu:size(2) do
        mu[{i, j}] = mu[{i, j}] + h
        local fph = kld:forward({pmu, plv}, {mu, lv}) -- f at +h
        mu[{i, j}] = mu[{i, j}] - h - h
        local fmh = kld:forward({pmu, plv}, {mu, lv}) -- f at -h
        mu[{i, j}] = mu[{i, j}] + h                   -- restore original value
        print((fph - fmh)/2/h - dmu[{i, j}])
    end
end


Out[7]:
-4.8075159519989e-08	
-3.9645419003254e-08	
-3.7435487487691e-08	
-3.3182473102578e-08	
-3.0266050998762e-08	
6.8491714522168e-09	
4.4403758781009e-08	
-2.1189407561906e-08	
1.5683021103996e-08	
-1.1057379567525e-08	
2.0470082517932e-08	
1.1078324479996e-08	

In [8]:
-- Central-difference gradient check for dlv (gradient w.r.t. the target
-- log-variance); same scheme as the dmu check.
-- Fix: fph/fmp were accidental globals; declared local here.
for i = 1,lv:size(1) do
    for j = 1,lv:size(2) do
        lv[{i, j}] = lv[{i, j}] + h
        local fph = kld:forward({pmu, plv}, {mu, lv}) -- f at +h
        lv[{i, j}] = lv[{i, j}] - h - h
        local fmh = kld:forward({pmu, plv}, {mu, lv}) -- f at -h
        lv[{i, j}] = lv[{i, j}] + h                   -- restore original value
        print((fph - fmh)/2/h - dlv[{i, j}])
    end
end


Out[8]:
6.0015261510449e-09	
3.7806444641575e-09	
1.1994583715147e-08	
-1.7050326817092e-08	
3.8166824367636e-08	
-1.3593235426157e-08	
1.1548187864308e-08	
2.54751284956e-07	
-3.1273485467942e-09	
2.5779371637391e-08	
-7.1380505339835e-08	
1.934774405965e-08	

In [9]:
-- Central-difference gradient check for dpmu (gradient w.r.t. the prior
-- mean). Note Out[9] is the exact negation of Out[7] — consistent with a
-- KLD term symmetric in (mu - pmu).
-- Fix: fph/fmp were accidental globals; declared local here.
for i = 1,pmu:size(1) do
    for j = 1,pmu:size(2) do
        pmu[{i, j}] = pmu[{i, j}] + h
        local fph = kld:forward({pmu, plv}, {mu, lv}) -- f at +h
        pmu[{i, j}] = pmu[{i, j}] - h - h
        local fmh = kld:forward({pmu, plv}, {mu, lv}) -- f at -h
        pmu[{i, j}] = pmu[{i, j}] + h                 -- restore original value
        print((fph - fmh)/2/h - dpmu[{i, j}])
    end
end


Out[9]:
4.8075159519989e-08	
3.9645419003254e-08	
3.7435487487691e-08	
3.3182473102578e-08	
3.0266050998762e-08	
-6.8491714522168e-09	
-4.4403758781009e-08	
2.1189407561906e-08	
-1.5683021103996e-08	
1.1057379567525e-08	
-2.0470082517932e-08	
-1.1078324479996e-08	

In [10]:
-- Central-difference gradient check for dplv (gradient w.r.t. the prior
-- log-variance); same scheme as the other checks.
-- Fix: fph/fmp were accidental globals; declared local here.
for i = 1,plv:size(1) do
    for j = 1,plv:size(2) do
        plv[{i, j}] = plv[{i, j}] + h
        local fph = kld:forward({pmu, plv}, {mu, lv}) -- f at +h
        plv[{i, j}] = plv[{i, j}] - h - h
        local fmh = kld:forward({pmu, plv}, {mu, lv}) -- f at -h
        plv[{i, j}] = plv[{i, j}] + h                 -- restore original value
        print((fph - fmh)/2/h - dplv[{i, j}])
    end
end


Out[10]:
2.0744828077568e-08	
2.1236241598555e-08	
1.8659129796816e-09	
4.1080057772147e-08	
3.6271748271588e-08	
1.0632842828429e-08	
-1.0203965977729e-08	
2.0403058442753e-07	
3.047924757027e-08	
-6.4392204568442e-08	
-2.4210677906922e-09	
-2.1025407137554e-08	

Basic matrix checks: two equivalent ways to compute the Mahalanobis quadratic form for a diagonal covariance


In [ ]:
-- check my implementation is correct
-- create mu's
-- Two random mean column vectors for the Mahalanobis-distance comparison.
mu1 = torch.randn(3,1)
mu2 = torch.randn(3,1)

In [ ]:
-- create sigma
sig = torch.diag(torch.randn(3):pow(2))
sig_lin = torch.diag(sig):reshape(3,1)

In [ ]:
-- first approach
-- Full quadratic form (mu1 - mu2)' * sig^-1 * (mu1 - mu2) as a 1x1 matrix.
-- NOTE(review): a bare expression statement is only valid in the iTorch
-- REPL; assign the result to a variable if running as a plain Lua script.
(mu1 - mu2):t() * (torch.inverse(sig)) * (mu1 - mu2)

In [ ]:
-- second approach
-- Element-wise: sum((mu1 - mu2)^2 / diag(sig)) — equivalent to the matrix
-- form because sig is diagonal. pow/cdiv mutate only the temporary tensor
-- produced by (mu1 - mu2), so mu1 and mu2 are untouched.
torch.Tensor(1):fill((mu1 - mu2):pow(2):cdiv(sig_lin):sum())

Test in-place tensor operations


In [ ]:
x = torch.Tensor(1):fill(3)

In [ ]:
x:add(3):add(x)